我正在尝试在Keras中开发一个可用于3D张量的图层。为了使其更加灵活,我想尽可能推迟依赖于输入的确切形状的代码。
我的图层覆盖了5种方法:
from tensorflow.python.keras.layers import Layer
class MyLayer(Layer):
def __init__(self, **kwargs):
pass
def build(self, input_shape):
pass
def call(self, inputs, verbose=False):
second_dim = K.int_shape(inputs)[-2]
# Do something with the second_dim
def compute_output_shape(self, input_shape):
pass
def get_config(self):
pass
我正在像这样使用这一层:
input = Input(batch_shape=(None, None, 128), name='input')
x = MyLayer(name='my_layer')(input)
model = Model(input, x)
但是我遇到一个错误,因为second_dim
是None
。我该如何开发一个依赖于输入维度的图层,但是可以由实际数据而不是输入图层来提供呢?
答案 0 :(得分:0)
我最终以不同的方式问了同样的问题,并且我得到了一个完美的答案:
What is the right way to manipulate the shape of a tensor when there are unknown elements in it?
要点是,不要直接处理尺寸。通过引用而不是按值使用它们。因此,请勿使用K.int_shape
,而应使用K.shape
。并使用Keras运算来合成并提出新的形状:
shape = K.shape(x)
newShape = K.concatenate([
shape[0:1],
shape[1:2] * shape[2:3],
shape[3:4]
])