如何使用tf.train.shuffle_batch为训练和推理构建TF图?

时间:2016-10-31 10:58:44

标签: tensorflow

上下文:卷积自动编码器

我使用tf.add_to_collections()

将input_tensor和output_tensor添加到图表中
fee

在推理期间,我调用tf.get_collections()来提取这些节点,然后调用sess.run()。它说形状固定在[512,32,32,3],我该如何解决?

1 个答案:

答案 0 :(得分:1)

一个可能的解决方案是使用tf.placeholder_with_default() op来放宽输入操作上的形状要求。例如:

50

如果您运行的代码取决于input_batch = tf.train.shuffle_batch(..., batch_size=512) input_placeholder = tf.placeholder_with_default(input_batch, [None, None, None, 3]) 但不提供,则会使用input_placeholder的结果。或者,如果您为tf.train.shuffle_batch()提供值,则可以输入任何4-D张量(深度为3),因此您可以使用任何批量大小或图像大小。

但请注意,这样做会禁用训练中的一些优化,因为每个批次的形状现在可以变化,至少原则上如此。这可以防止TensorFlow将某些内部tf.shape()调用视为常量值,这可能意味着它需要在每个训练步骤中执行更多工作。最后,为培训和推理构建两个单独的图可能会更好,因为这些图可以单独优化。