code由Tensorflow提供,这是在进行Trianing时如何获取ImageNet TFrecord文件的方法:
import tensorflow as tf
import imagenet_data
import image_processing
imagenet_data_train = imagenet_data.ImagenetData('train')
train_images, train_labels = image_processing.inputs(imagenet_data_train, batch_size=256, num_preprocess_threads=16)
coord = tf.train.Coordinator()
threads = []
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
for i in range(1000):
image_batch, label_batch = sess.run([train_images, train_labels ])
finally:
coord.request_stop()
coord.join(threads)
现在我只想使用一半的训练数据(可能是Tfreocd文件中的前60万个数据)在训练期间进行迭代,我应该设置什么?
答案 0 :(得分:0)
我通过添加额外的参数train_set_number_rate
修改ImagenetData train_set_number_rate
类和修改方法data_files
来解决此问题,该方法控制文件名列表传递给下一个函数。
请注意,以这种方式,我必须使用函数distorted_inputs
而不是inputs
,以确保在不同的模型中使用相同的火车集合部分。(这可能会产生不良影响但由于我已经使用inputs
来训练原始网络,因此必须使用inputs
来确保比较的正确性。)
class ImagenetData(Dataset):
def __init__(self, subset, train_set_number_rate = None):
super(ImagenetData, self).__init__('ImageNet', subset)
self.train_set_number_rate = train_set_number_rate
def data_files(self):
"""Returns a python list of all (sharded) data subset files.
Returns:
python list of all (sharded) data set files.
Raises:
ValueError: if there are not data_files matching the subset.
"""
tf_record_pattern = os.path.join(data_dir, '%s-*' % self.subset)
data_files = tf.gfile.Glob(tf_record_pattern)
if self.subset=='validation':
assert(self.train_set_number_rate==None)
elif self.subset=='train':
if self.train_set_number_rate!=None:
data_files = data_files[0:round(len(data_files)*self.train_set_number_rate)]
if not data_files:
print('No files found for dataset %s/%s at %s' % (self.name,
self.subset,
data_dir))
self.download_message()
exit(-1)
return data_files