Tensorflow:连接多个tf.Dataset非常慢

时间:2018-08-31 07:39:31

标签: python tensorflow tensorflow-datasets

我正在使用Tensorflow 1.10

现在我不确定这是否是错误。

我一直在尝试合并大约100个我从多个tf.data.Dataset.from_generator生成的数据集。

for i in range(1, 100):
        dataset = dataset.concatenate(
            tf.data.Dataset.from_generator(gens[i], (tf.int8, tf.int32), output_shapes=(
                (256, 256), (1))))
        print(i)
 print("before iterator")
 iterator = dataset.make_one_shot_iterator()
 print("after iterator")

运行make_one_shot_iterator()会花费很长时间。

有人知道解决办法吗?

编辑:

看起来像_make_dataset.add_to_graph(ops.get_default_graph()) 似乎一遍又一遍地调用,导致该函数的数百万次调用。 (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/data/ops/dataset_ops.py函数make_one_shot_iterator第162行)

1 个答案:

答案 0 :(得分:0)

对于这样的多个张量或生成器,运行$1_DB;实际上并不是最好的选择。

更好的方法是使用concatenate https://www.tensorflow.org/api_docs/python/tf/data/Dataset#flat_map。我确实做了一段时间的示例更新,以展示如何将其用于多个张量或文件。