Tensorflow:使用Dataset.from_tensor_slices()的非常大的估计器日志

时间:2018-05-16 20:55:50

标签: python tensorflow

我一直在研究mnist估算代码(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/layers/cnn_mnist.py) 在使用此代码进行训练或15万步之后,估算器生成的日志大小为31M。 (每个重量检查点为13M,图形定义为5M)。

在修补代码时,我使用tf.data.Dataset.from_tensor_slices()编写了我自己的train_input_fn。 我的代码在这里:

def my_train_input_fn():
    mnist = tf.contrib.learn.datasets.load_dataset("mnist")
    images = mnist.train.images  # Returns np.array
    labels = np.asarray(mnist.train.labels, dtype=np.int32)
    dataset = tf.data.Dataset.from_tensor_slices(
        ({"x": images}, labels))
    dataset = dataset.shuffle(50000).repeat().batch(100)

    return dataset

并且,我的日志,甚至在训练的一个步骤之前,仅在图形初始化之后,大小超过1,5G! (ckpt-meta为165M,每个events.out.tfevents和graph.pbtxt文件大约为600M)。

经过一番研究后我发现函数from_tensor_slices()不适合较大的数据集,因为它在执行图中创建了常量。

  

请注意,上面的代码段会嵌入功能和标签   TensorFlow图中的数组为tf.constant()操作。这个   适用于小型数据集,但浪费内存 - 因为   数组的内容将被复制多次---并且可以进入   tf.GraphDef协议缓冲区的2GB限制。

源: https://www.tensorflow.org/programmers_guide/datasets

但是mnist数据集的大小只有大约13M。那么为什么我的图形定义有600M,而不仅仅是那些作为常量嵌入的13M呢?为什么事件文件如此之大?

生成代码(https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/python/estimator/inputs/numpy_io.py)的原始数据集不会生成如此大的日志文件。我想这是因为队列的使用。但现在队列已被弃用,我们应该使用tf.Dataset而不是队列,对吗?从包含图像的文件(而不是TFRecord)创建此类数据集的正确方法是什么?我应该使用tf.data.FixedLengthRecordDataset吗?

1 个答案:

答案 0 :(得分:2)

我有一个类似的问题,我使用tf.data.Dataset.from_generator解决了 或先输入tf.data.Dataset.range,再输入tdata.map以获得特定值。

例如带有发电机

def generator():
   for sample in zip(*datasets_tuple):
     yield sample

dataset = tf.data.Dataset.from_generator(generator,
       output_types=output_types, output_shapes=output_shapes)