如何在具有可变输入图像大小的模型中获取图层的output_shape

时间:2017-09-01 20:37:22

标签: tensorflow keras

我正在修补一个VGG16卷积网络的Keras实现(我没有自己构建)。 Tensorflow后端。输入图像大小不同,所以我使用None指定第一层,用于可变宽度和高度。

model.add(ZeroPadding2D((1, 1), input_shape=(3, None, None)))

问题是,在构建损失的过程中,我需要得到一个卷积层的output_shape,当然这会出现一些未定义的维度。

我想知道是否有办法设置第一层的输入宽度和高度,只是为了从我的图层堆栈中间计算这个output_shape。我不熟悉算术,通过图层链自己计算。

我应该说我是一个菜鸟,因此会感谢冗长的答案。

2 个答案:

答案 0 :(得分:3)

您可以使用该图层中输出张量的形状,而不是使用图层的output_shapeK.shape(x)为您提供张量x的形状。动态轴(即None轴)将在运行时填充相应的宽度和高度。

这是一个示例,说明如何在自定义的损失中使用中间层的输出形状(损失函数本身没有意义,只是为了表明shape根据输入数组求值为不同的值) :

input_tensor = Input(shape=(3, None, None))
middle_tensor = Conv2D(100, 1)(input_tensor)  # shape = (100, None, None)
output_tensor = GlobalMaxPooling2D()(middle_tensor)  # not important
model = Model(input_tensor, output_tensor)

def get_loss(shape):
    def dummy_loss(y_true, y_pred):
        return K.cast(K.prod(shape), K.floatx())
    return dummy_loss
dummy_loss = get_loss(K.shape(middle_tensor))
model.compile(loss=dummy_loss, optimizer='sgd')

print(model.evaluate(np.zeros((1, 3, 2, 2)), np.zeros((1, 1))))
=> 400.0

print(model.evaluate(np.zeros((1, 3, 224, 224)), np.zeros((1, 1))))
=> 5017600.0

正如您所看到的,在第一次调用中,K.shape(middle_tensor)评估为(100, 2, 2),因此K.prod(shape)为400.在第二次调用中,K.shape(middle_tensor)评估为(100, 224, 224) } K.prod(shape)变为5017600。

答案 1 :(得分:0)

如果您想使用卷积图层(如VGG16),您必须将图像调整到正确的尺寸,如果您想使用预训练的重量,您必须使用与训练时相同的尺寸(224x224堡垒图像网络训练权重)。

你的ImageDataGenerator()可以为你调整大小(下面是img_height和omg_weight)

train_datagen=ImageDataGenerator()
valid_datagen=ImageDataGenerator()
train_generator =train_datagen.flow_from_directory(
    train_path,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical')
train_filenames = train_generator.filenames
train_samples = len(train_filenames)

validation_generator = validation_datagen.flow_from_directory(
    valid_path,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle = False) #Need this to be false, so I can extract the correct classes and filenames in order that that are predicted
validation_filenames = validation_generator.filenames
validation_samples = len(validation_filenames)