上下文:卷积自动编码器
我使用tf.add_to_collections()
将input_tensor和output_tensor添加到图表中fee
在推理期间,我调用tf.get_collections()来提取这些节点,然后调用sess.run()。它说形状固定在[512,32,32,3],我该如何解决?
答案 0 :(得分:1)
一个可能的解决方案是使用tf.placeholder_with_default()
op来放宽输入操作上的形状要求。例如:
50
如果您运行的代码取决于input_batch = tf.train.shuffle_batch(..., batch_size=512)
input_placeholder = tf.placeholder_with_default(input_batch, [None, None, None, 3])
但不提供,则会使用input_placeholder
的结果。或者,如果您为tf.train.shuffle_batch()
提供值,则可以输入任何4-D张量(深度为3),因此您可以使用任何批量大小或图像大小。
但请注意,这样做会禁用训练中的一些优化,因为每个批次的形状现在可以变化,至少原则上如此。这可以防止TensorFlow将某些内部tf.shape()
调用视为常量值,这可能意味着它需要在每个训练步骤中执行更多工作。最后,为培训和推理构建两个单独的图可能会更好,因为这些图可以单独优化。