How can I iterate over only half of the ImageNet training set with the code provided by TensorFlow?

Date: 2020-06-15 07:01:42

Tags: python tensorflow imagenet

The code is provided by TensorFlow. This is how the ImageNet TFRecord files are read during training:

```python
import tensorflow as tf
import imagenet_data      # from the TensorFlow models Inception code
import image_processing

imagenet_data_train = imagenet_data.ImagenetData('train')
# Builds the queue-based input pipeline over the training shards.
train_images, train_labels = image_processing.inputs(
    imagenet_data_train, batch_size=256, num_preprocess_threads=16)

coord = tf.train.Coordinator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Queue runner threads need a live session, so start them inside the block.
    threads = []
    for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))

    try:
        for i in range(1000):
            image_batch, label_batch = sess.run([train_images, train_labels])
    finally:
        coord.request_stop()
        coord.join(threads)
```

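Now I want to iterate over only half of the training data during training (for example, the first 600,000 examples in the TFRecord files). What should I set?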

1 answer:

Answer 0 (score: 0)

I solved this by adding an extra parameter, `train_set_number_rate`, to the `ImagenetData` class and modifying its `data_files` method, which controls the list of file names passed on to the next function.
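A minimal sketch of the same selection logic in plain Python, so it can be checked without TensorFlow or the ImageNet files. `select_shards` is a hypothetical helper, and the shard names assume the `train-00000-of-01024` pattern written by the Inception `build_imagenet_data.py` script (1024 training shards by default); adapt both to your setup.

```python
import fnmatch


def select_shards(all_files, subset, rate=None):
    """Return the shard files for `subset`, keeping only the first `rate` fraction.

    Hypothetical helper mirroring the modified ImagenetData.data_files():
    files are sorted first so the same shards are chosen on every run.
    """
    files = sorted(f for f in all_files if fnmatch.fnmatchcase(f, '%s-*' % subset))
    if rate is not None:
        if not 0 < rate <= 1:
            raise ValueError('rate must be in (0, 1], got %r' % rate)
        # int(round(...)) also works on Python 2, where round() returns a float.
        files = files[:int(round(len(files) * rate))]
    return files


if __name__ == '__main__':
    names = ['train-%05d-of-01024' % i for i in range(1024)]
    names += ['validation-%05d-of-00128' % i for i in range(128)]
    half = select_shards(reversed(names), 'train', 0.5)
    print(len(half), half[0], half[-1])
```

Sorting is the important design choice: glob results are not guaranteed to be sorted, so without it "the first half" could be a different set of shards from run to run. Because the cut is made at the file level, the result is half of the shards, which for roughly equal-sized shards is roughly, not exactly, half of the images.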

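Note that with this approach I have to use the function `inputs` rather than `distorted_inputs`, to make sure the same portion of the training set is used across different models. (This may have some negative effect, but since I had already trained the original network with `inputs`, I must keep using `inputs` for the comparison to be valid.)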

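```python
import os

import tensorflow as tf

# As in inception/imagenet_data.py; adjust the import path to your layout.
from inception.dataset import Dataset

FLAGS = tf.app.flags.FLAGS  # the --data_dir flag is defined in dataset.py


class ImagenetData(Dataset):
    # num_classes(), num_examples_per_epoch() and download_message()
    # are unchanged from imagenet_data.py and omitted here.

    def __init__(self, subset, train_set_number_rate=None):
        super(ImagenetData, self).__init__('ImageNet', subset)
        # Fraction of the training shards to use (e.g. 0.5); None uses all.
        self.train_set_number_rate = train_set_number_rate

    def data_files(self):
        """Returns a python list of all (sharded) data subset files.

        Returns:
          python list of all (sharded) data set files.
        Exits the program if no data files match the subset.
        """
        tf_record_pattern = os.path.join(FLAGS.data_dir, '%s-*' % self.subset)
        # Sort so that "the first N files" is the same set on every run.
        data_files = sorted(tf.gfile.Glob(tf_record_pattern))
        if self.subset == 'validation':
            assert self.train_set_number_rate is None
        elif self.subset == 'train' and self.train_set_number_rate is not None:
            # int() is needed on Python 2, where round() returns a float.
            num_files = int(round(len(data_files) * self.train_set_number_rate))
            data_files = data_files[:num_files]
        if not data_files:
            print('No files found for dataset %s/%s at %s' % (self.name,
                                                              self.subset,
                                                              FLAGS.data_dir))
            self.download_message()
            exit(-1)
        return data_files
```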
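Because the answer cuts the file list, it selects whole shards rather than exactly the first 600,000 examples. If you need to know how many shards cover a target number of examples, or how many examples a subset actually holds (for instance, if your training script derives epoch length from `num_examples_per_epoch()`, which still reports the full training set), you can count the records. The sketch below reads the TFRecord framing directly (an 8-byte little-endian length, a 4-byte masked CRC of the length, the payload, and a 4-byte masked CRC of the payload) and skips the CRCs instead of verifying them. `count_tfrecords` and `shards_for_examples` are hypothetical helpers, not TensorFlow APIs.

```python
import os
import struct


def count_tfrecords(path):
    """Count the records in one TFRecord file without TensorFlow.

    Each record is framed as: uint64 length (little-endian), uint32 masked
    CRC of the length, `length` payload bytes, uint32 masked CRC of the
    payload. The CRCs are skipped, not checked.
    """
    count = 0
    with open(path, 'rb') as f:
        size = os.fstat(f.fileno()).st_size
        while True:
            header = f.read(8)
            if not header:
                break  # clean end of file
            if len(header) < 8:
                raise IOError('truncated length field in %s' % path)
            (length,) = struct.unpack('<Q', header)
            # Skip the length CRC, the payload and the payload CRC.
            f.seek(4 + length + 4, os.SEEK_CUR)
            if f.tell() > size:
                raise IOError('truncated record in %s' % path)
            count += 1
    return count


def shards_for_examples(files, target):
    """Return the shortest prefix of `files` holding at least `target` records,
    together with the number of records in that prefix."""
    chosen, total = [], 0
    for path in files:
        if total >= target:
            break
        chosen.append(path)
        total += count_tfrecords(path)
    return chosen, total
```

For example, `shards_for_examples(sorted(glob.glob('/path/to/train-*')), 600000)` would return the shortest sorted prefix of shards holding at least 600,000 records, plus their exact count; the path is a placeholder.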