keras的compute_output_shape不适用于自定义层

时间:2019-09-24 02:39:11

标签: keras shapes batchsize

我定制了一个图层,将batch_size和第一个维度合并,其他维度保持不变,但是compute_output_shape似乎没有任何效果,导致随后的图层无法获取准确的形状信息,从而导致错误。如何使compute_output_shape工作?

import keras
from keras import backend as K

class BatchMergeReshape(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(BatchMergeReshape, self).__init__(**kwargs)

    def build(self, input_shape):
        super(BatchMergeReshape, self).build(input_shape)  

    def call(self, x):
        input_shape = K.shape(x)
        batch_size, seq_len = input_shape[0], input_shape[1]
        r = K.reshape(x, (batch_size*seq_len,)+input_shape[2:])
        print("call_shape:",r.shape)
        return r

    def compute_output_shape(self, input_shape):
        if input_shape[0] is None:
            r = (None,)+input_shape[2:]
            print("compute_output_shape:",r)
            return r
        else:
            r = (input_shape[0]*input_shape[1],)+input_shape[2:]
            return r

a = keras.layers.Input(shape=(3,4,5))
b = BatchMergeReshape()(a)
print(b.shape)

# call_shape: (?, ?)
# compute_output_shape: (None, 4, 5)
# (?, ?)

我需要获取(None,4,5)但要获取(None,None),为什么compute_output_shape不起作用。我的keras版本是2.2.4

1 个答案:

答案 0 :(得分:0)

问题可能是K.shape返回了一个张量而不是一个元组。您不能做(batch_size*seq_len,) + input_shape[2:]。这混合了很多东西,张量和元组,结果肯定是错误的。

现在的好处是,如果您知道其他尺寸而不是批量大小,则只需要这一层:

Lambda(lambda x: K.reshape(x, (-1,) + other_dimensions_tuple))

如果不这样:

input_shape = K.shape(x)
new_batch_size = input_shape[0:1] * input_shape[1:2] #needs to keep a shape of an array   
                 #new_batch_size.shape = (1,)
new_shape = K.concatenate([new_batch_size, input_shape[2:]]) #this is a tensor   
                                                             #result of concatenating 2 tensors   

r = K.reshape(x, new_shape)

请注意,这在Tensorflow中有效,但在Theano中可能无效。

还请注意,Keras将要求模型输出的批量大小等于模型输入的批量大小。这意味着您将需要在模型结束之前恢复原始的批次大小。