使用`tf.train.batch`时保留[None,None,None,x]输入形状

时间:2018-02-15 19:22:28

标签: python tensorflow

我正在尝试训练一个完全卷积模型,它可以使用张量流1.5来获取任何输入分辨率的图像。

目前我正在做这样的事情:

image, segmentation = \
    D.get_training_dataset_data_provider()

image, segmentation = \
    tf.train.batch([image, segmentation],
                    batch_size=16)

# Define the model:
predictions, loss, end_points = M.model_w_batch_norm(
    image,
    segmentation
            )

然后我继续使用tf.train.MonitoredTrainingSession训练模型。一切都很好,除了我得到一个具有固定输入形状的模型(即image_in具有我从我的数据集数据提供者提供的形状)。

get_training_dataset_data_provider加载来自tfrecord的图像并对其进行扩充。输出是恒定分辨率,适合训练。然而,对于测试/预测,我希望能够传递任何形状的图像。

模型本身是完全卷积的(在M.model_w_batch_norm中)定义了张量image_in

def model_w_batch_norm(in_image, trainable=True):
    in_image = tf.identity(in_image, name="image_in")
    ...

如果我不使用tf.train.batch,我只会定义一个Placeholder形状(无,无,无,x),但我想我无法将tf.train.batch链接到占位符,我可以吗?

如何为形状(None, None, None, x)的模型定义输入张量?

1 个答案:

答案 0 :(得分:0)

实际上,如果模型的代码可用,则解决方案变得简单。在这种情况下,您可以使用

重建模型
ph = tf.Placeholder(shape=[None, None, None, 3])
predictions, loss, end_points = M.model_w_batch_nor
m(
    ph,
    segmentation
            )

在我的例子中。然后使用Saver创建slim.get_variables_to_restore()并从检查点恢复模型。这样就构建了一个不包含批处理运算符的新图。