如何开发适用于任意大小输入的图层

时间:2019-10-07 01:24:21

标签: keras

我正在尝试在Keras中开发一个可用于3D张量的图层。为了使其更加灵活,我想尽可能推迟依赖于输入的确切形状的代码。

我的图层覆盖了5种方法:

from tensorflow.python.keras.layers import Layer

class MyLayer(Layer):
    def __init__(self, **kwargs):
        pass

    def build(self, input_shape):
        pass

    def call(self, inputs, verbose=False):
        second_dim = K.int_shape(inputs)[-2]
        # Do something with the second_dim

    def compute_output_shape(self, input_shape):
        pass

    def get_config(self):
        pass

我正在像这样使用这一层:

input = Input(batch_shape=(None, None, 128), name='input')
x = MyLayer(name='my_layer')(input)
model = Model(input, x)

但是我遇到一个错误,因为second_dimNone。我该如何开发一个依赖于输入维度的图层,但是可以由实际数据而不是输入图层来提供呢?

1 个答案:

答案 0 :(得分:0)

我最终以不同的方式问了同样的问题,并且我得到了一个完美的答案:

What is the right way to manipulate the shape of a tensor when there are unknown elements in it?

要点是,不要直接处理尺寸。通过引用而不是按值使用它们。因此,请勿使用K.int_shape,而应使用K.shape。并使用Keras运算来合成并提出新的形状:

shape = K.shape(x)
newShape = K.concatenate([
                             shape[0:1], 
                             shape[1:2] * shape[2:3],
                             shape[3:4]
                         ])