Lambda slice导致InvalidArgumentError:形状不兼容

时间:2019-08-31 20:00:27

标签: keras keras-layer

当我使用Lambda层对张量进行切片,然后对其进行连接时,我的模型最终会产生错误:“ InvalidArgumentError:不兼容的形状:[4,2]与[8]”

我尝试在各个地方指定input_shape和output_shape。这似乎没有什么不同。我也尝试过在Concatenate层中更改axis参数。如果使用-1(默认值),则会出现相同的错误。如果使用0,则会收到一个不同的错误,指示当期望(1,)时输出形状为(2,)。 axis = 1会出现错误,指示输入的等级不足。

我在下面制作了简短的代码来重现问题。

该模型采用形状(None,3)的输入,并假定将(None,2)的输出作为类别。构成数据只是为了说明问题。

即使这是一个没有参数的小模型,但我的大模型却得到了相同的错误消息。

仅供参考,我正在使用Google Colab。

# repro of - InvalidArgumentError: Incompatible shapes
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras import optimizers
from keras.layers import Concatenate
from keras.layers import Input
from keras.layers import Lambda
from keras.models import Model

print ("tensorflow version %s" % tf.__version__)
print ("keras version %s" % keras.__version__)

# Mock data
X = pd.DataFrame([[0,1,0], [0,2,1], [7,9,0], [4,5,6]])
Y = pd.DataFrame([[0,1], [0,1], [0,1], [1,0]])
print(X)
print("Multi-layer NN training. X size %s. Y size %s."
      % (X.shape, Y.shape))

input_layer_X = Input(shape=(X.shape[1],))
a_slice = Lambda(lambda x: x[:, 0], output_shape=(1,), name='a_slice')(input_layer_X)
b_slice = Lambda(lambda x: x[:, 1], output_shape=(1,), name='b_slice')(input_layer_X)
concat_layer = Concatenate(name='concat_example', axis=-1)([a_slice, b_slice])

model = Model(input_layer_X, concat_layer)
sgd = optimizers.SGD(lr=0.00001)
model.compile(sgd, loss='categorical_crossentropy')
print(model.summary())

model.fit(X, Y, epochs=3)

模型拟合将运行。相反,我从model.fit调用中得到了此错误消息:

InvalidArgumentError: Incompatible shapes: [4,2] vs. [8]
     [[{{node loss_8/concat_example_loss/mul}}]]

但是据我所知,模型摘要看起来像我已经正确设置了:

Model: "model_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_9 (InputLayer)            (None, 3)            0                                            
__________________________________________________________________________________________________
a_slice (Lambda)                (None, 1)            0           input_9[0][0]                    
__________________________________________________________________________________________________
b_slice (Lambda)                (None, 1)            0           input_9[0][0]                    
__________________________________________________________________________________________________
concat_example (Concatenate)    (None, 2)            0           a_slice[0][0]                    
                                                                 b_slice[0][0]                    
==================================================================================================

TF和keras版本信息:

tensorflow version 1.14.0
keras version 2.2.4-tf

感谢您的帮助。

0 个答案:

没有答案