iterator.get_next()返回字节数组/ iterator.get_next()无法分配给多个值,除非他们急于执行

时间:2018-07-26 11:34:14

标签: python tensorflow tensorflow-datasets tfrecord

我在尝试创建tf.Dataset时遇到问题 通过tf.data.TFRecordDatasettfrecord文件中删除。

def parse_function(example_proto):
# Defaults are not specified since both keys are required. 
keys_to_features={
      'image': tf.FixedLenFeature([1024*1024],tf.int64),
      'label': tf.FixedLenFeature([1024*1024],tf.int64)
}
features = tf.parse_example([example_proto],keys_to_features)
label = features['label']
image = features['image']
label = tf.reshape(label,(1024,1024))
image = tf.reshape(image,(1024,1024))
return image,label

def make_batch(batch_size):
    filenames = ["train.tfrecords"]
    tf.data.TFRecordDataset(filenames).repeat()
    dataset.map(map_func=parse_function,num_parallel_calls=batch_size)
    dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    image , label  = iterator.get_next()
    return image , label

这导致了错误:

  

未启用急切执行时,张量对象不可迭代。要遍历此张量,请使用tf.map_fn。

所以我更改了:image , label = iterator.get_next() 到:next_elem = iterator.get_next()

有了这个,我可以执行以下代码:

with tf.Session() as sess: 
sess.run(tf.global_variables_initializer())
next_elem   = sess.run( make_batch(1))

但是,next_elem是字节数组,而不是形状为[[1024,1024],[1024,1024])的元组。

1 个答案:

答案 0 :(得分:1)

所以事实证明,错误只是我的误解。

dataset.map(map_func=parse_function,num_parallel_calls=batch_size)
dataset.batch(batch_size)

不操纵数据集本身,请参见:Iterator.get_next() returning a tensor of shape ()

您实际上必须再次将这些操作产生的数据集分配给数据集,如下所示: dataset = dataset.map(map_func=parse_function,num_parallel_calls=batch_size) dataset = dataset.batch(batch_size)

这实际上也解决了iterator.get_next()问题。因此,我将next_elem = iterator.get_next()更改为:image , label = iterator.get_next()

,并且以下代码可以按预期工作:with tf.Session() as sess: sess.run(tf.global_variables_initializer()) image , label = sess.run( make_batch(1))