如何访问由tf.data.Dataset.list_files()收集的文件名?

时间:2018-07-03 20:36:24

标签: python tensorflow tensorflow-datasets

我正在使用

file_data = tf.data.Dataset.list_files("../*.png")

收集图像文件以在TensorFlow中进行训练,但想访问收集的文件名列表,以便执行标签查找。

调用sess.run([file_data])失败:

TypeError: Fetch argument <TensorSliceDataset shapes: (), types: tf.string> has invalid type <class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>, must be a string or Tensor. (Can not convert a TensorSliceDataset into a Tensor or Operation.)

我还有其他方法可以使用吗?

2 个答案:

答案 0 :(得分:3)

通过一些额外的实验,我找到了解决此问题的方法:

首先,将数据集转换为迭代器:

iterator_helper = file_data.make_one_shot_iterator()

然后,遍历tf会话中的元素:

with tf.Session() as sess:
    filename_temp = iterator_helper.get_next()
    print(sess.run[filename_temp])

答案 1 :(得分:2)

Dataset.list_files() API使用tf.matching_files() op列出与给定模式匹配的文件。您还可以使用该操作以tf.Tensor的形式获取文件列表,并将其直接传递给sess.run()

filenames_as_tensor = tf.matching_files("../*.png")
filenames_as_array = sess.run(filenames_as_tensor)

for filename in filenames_as_array:
  print(filename)