我想将数据集迭代器传递给函数,但是该函数需要知道数据集的长度。在下面的示例中,我可以将len(datafiles)
传递到my_custom_fn()
函数中,但是我想知道是否能够从iterator
,{{1}中提取数据集的长度。 }或batch_x
类,这样我就不必将其添加为输入。
batch_y
谢谢!
编辑:在我的情况下,此解决方案不起作用:tf.data.Dataset: how to get the dataset size (number of elements in a epoch)?
运行后
dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff
返回
tf.data.Dataset.list_files('{}/*.dat')
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])
我确实找到了适合我的解决方案。将iterator_scope添加到我的代码中,例如:
<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>
然后从with tf.name_scope('iter'):
dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff
内进行呼叫:
my_custom_fn
不确定这是否是最好的方法,但似乎可行。很乐意对此提出任何建议,因为它似乎有些怪异。
答案 0 :(得分:0)
iterator
的长度是未知的,直到您遍历为止。您可以将len(datafiles)
显式传递给函数,但是如果坚持数据的持久性,则可以简单地使函数成为实例方法,并将数据集的长度存储在my_custom_fn
的对象中是一种方法。
不幸的是,由于iterator
不存储任何内容,因此会动态生成数据。但是,正如在TensorFlow的源代码中发现的那样,存在一个“私有”变量_batch_size
,用于存储批量大小。您可以在此处查看源代码:TensorFlow source。