确定tf.data.Dataset Tensorflow中的记录数

时间:2018-09-10 19:43:58

标签: python tensorflow machine-learning deep-learning

我想将数据集迭代器传递给函数,但是该函数需要知道数据集的长度。在下面的示例中,我可以将len(datafiles)传递到my_custom_fn()函数中,但是我想知道是否能够从iterator,{{1}中提取数据集的长度。 }或batch_x类,这样我就不必将其添加为输入。

batch_y

谢谢!

编辑:在我的情况下,此解决方案不起作用:tf.data.Dataset: how to get the dataset size (number of elements in a epoch)?

运行后

dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff

返回

tf.data.Dataset.list_files('{}/*.dat')
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])

我确实找到了适合我的解决方案。将iterator_scope添加到我的代码中,例如:

<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>

然后从with tf.name_scope('iter'): dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes) iterator = dataset.make_initializable_iterator() sess.run(iterator.initializer) [batch_x, batch_y] = iterator.get_next() value = my_custom_fn(batch_x, batch_y) # lots of other stuff 内进行呼叫:

my_custom_fn

不确定这是否是最好的方法,但似乎可行。很乐意对此提出任何建议,因为它似乎有些怪异。

1 个答案:

答案 0 :(得分:0)

iterator的长度是未知的,直到您遍历为止。您可以将len(datafiles)显式传递给函数,但是如果坚持数据的持久性,则可以简单地使函数成为实例方法,并将数据集的长度存储在my_custom_fn的对象中是一种方法。

不幸的是,由于iterator不存储任何内容,因此会动态生成数据。但是,正如在TensorFlow的源代码中发现的那样,存在一个“私有”变量_batch_size,用于存储批量大小。您可以在此处查看源代码:TensorFlow source