TFRecord仅选择功能的一小部分

时间:2018-11-11 19:25:18

标签: python tensorflow

我有一个tfrecord文件,其中应该有〜8000个图像补丁,全部为128x128x4。我编写了一个与在tensorflow网站上找到的输入函数大致相同的输入函数: https://www.tensorflow.org/guide/datasets#randomly_shuffling_input_data

我读取tfrecord的代码是:

def tfrecord_input_fn(fileName,
                  numEpochs=None,
                  shuffle=None,
                  batch_size=None):
dataset = tf.data.TFRecordDataset(fileName, compression_type='GZIP')
dataset = dataset.map(lambda example_proto: tf.parse_single_example(example_proto, featuresDict))
if shuffle == True:
    dataset = dataset.shuffle(buffer_size=(batch_size * 10))
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(numEpochs)
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
return features

当我在网络上训练它时,我编写的所有内容似乎都可以正常工作...但是,我注意到,当我在会话中运行它时,它似乎只使用了很小一部分输入图片。我正在这样使用它:

with tf.Session() as sess:
    for ep in range(num_epochs):
        train_img_lab = tfrecord_input_fn(fileName=datapath, numEpochs=epoch_num, shuffle=True, train_batch_size=train_batch_size)

        a = sess.run(train_img_lab)
        image, label = create_np_arrray_from_tensor(a, train_batch_size)
        _, c = sess.run([train_op, train_loss], feed_dict= trainimg_placeholder: image, trainlabels_placeholder: label})

当我没有看到我期望的模型输出时,我写了一些东西将128X128图像保存到磁盘上。那是当我意识到不知何故我在每个时期的每批中使用相同的10张或15张图像(共8000张左右)。我怀疑这意味着我要么不理解我的tfrecord_input函数中的某些内容(也许是批处理还是重复函数?),要么我没有在会话中的正确位置调用它。有人可以帮助我了解我正在采取的步骤以及需要采取的步骤,以实际使用训练数据中的所有图像吗?

谢谢!

0 个答案:

没有答案