Tensorflow Dataset API将图形protobuff文件大小加倍

时间:2017-08-07 14:22:41

标签: python tensorflow skflow

摘要:使用新的tf.contrib.data.Dataset会使我的图表protobuff文件的大小加倍,并且我无法在Tensorboard中显示图形。

详细信息:

我正在尝试新的TensorFlow tf.contrib.data.Dataset功能以及tf.contrib.learn.Experiment框架。我的输入数据定义为input functions,它返回功能和标签的张量。

如果我使用tf.train.slice_input_producer函数创建输入函数,例如以下代码块(完整代码here),那么生成的graph.pbtxt文件为620M,.meta文件大小约为165M。

def train_inputs():
    with tf.name_scope('Training_data'):
        x = tf.constant(mnist.train.images.reshape([-1, 28, 28, 1]))
        y = tf.constant(mnist.train.labels)
        sliced_input = tf.train.slice_input_producer(
            tensor_list=[x, y], shuffle=True)
        return tf.train.shuffle_batch(
            sliced_input, batch_size=batch_size,
            capacity=10000, min_after_dequeue=batch_size*10)

现在,如果我使用新的tf.contrib.data.Dataset.from_tensor_slices创建我的输入函数,例如以下代码块(完整代码here),那么我生成的graph.pbtxt文件的大小会翻倍至1.3G并且.meta个文件的大小翻倍至330M。

def train_inputs():
    with tf.name_scope('Training_data'):
        images = mnist.train.images.reshape([-1, 28, 28, 1])
        labels = mnist.train.labels
        dataset = tf.contrib.data.Dataset.from_tensor_slices(
            (images, labels))
        dataset = dataset.repeat(None)  # Infinite
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        next_example, next_label = iterator.get_next()
        return next_example, next_label

现在因为graph.pbtxt文件太大TensorBoard需要很长时间来解析这个文件,而且我无法直观地调试我的模型图。 我在Dataset documentation中发现,这种增加的大小来自:"数组的内容将被多次复制" solution将是使用占位符。但是,在这种情况下,我需要使用活动会话将numpy数组输入占位符以初始化迭代器:

sess.run(iterator.initializer, feed_dict={features_placeholder: features, labels_placeholder: labels})

然而,在使用tf.contrib.learn.Experiment框架时,这似乎不受我的控制。

如何使用Experiment框架初始化迭代器的初始化程序?或者在不增加图表大小的情况下找到使用数据集API的变通方法?

1 个答案:

答案 0 :(得分:3)

我使用tf.train.SessionRunHook找到了解决问题的方法。我创建了一个SessionRunHook对象,在创建会话后初始化迭代器:

class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initiliser_func = None

    def after_create_session(self, session, coord):
        self.iterator_initiliser_func(session)

创建数据集迭代器时设置初始化函数:

iterator_initiliser_hook.iterator_initiliser_func = \
    lambda sess: sess.run(
        iterator.initializer,
        feed_dict={images_placeholder: images,
                   labels_placeholder: labels})

我将钩子对象传递给train_monitors的{​​{1}}和eval_hooks参数。

生成的tf.contrib.learn.Experiment文件现在只有500K,而graph.pbtxt文件只有244K。

Full example here.