我使用Keras构建模型,模型中有两个输入,其数据类型为'int32'。然后,我使用keras Lamba层通过K.gather(引用,索引)在嵌入矩阵中查找。我看到索引应该是int的张量,我认为我的代码满足了这一点,我不知道为什么会出错。我真的需要帮助!
input_A = Input(batch_shape=(128,1),name='A_input',dtype='int32')
input_B = Input(batch_shape=(128,1),name='B_input',dtype='int32')
input_A_ = Lambda(lambda x:K.reshape(x,(-1,)))(input_A)
input_B_ = Lambda(lambda x:K.reshape(x, (-1,)))(input_B)
input_A__ = Lambda(lambda x:K.cast(x,dtype='int32'))(input_A_)
input_B__ = Lambda(lambda x:K.cast(x,dtype='int32'))(input_B_)
embedded_text_A = Lambda(lambda x:K.gather(M1,x))(input_A__)
embedded_text_B = Lambda(lambda x:K.gather(M1,x))(input_B__)
答案 0 :(得分:0)
出于某种神秘的原因,如果将K.cast()
放在lambda
内,它将正常工作:
input_A = Input(batch_shape=(128,1), name='A_input', dtype='int32')
input_B = Input(batch_shape=(128,1), name='B_input', dtype='int32')
input_A_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_A)
input_B_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_B)
embedded_text_A = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_A_)
embedded_text_B = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_B_)
因此,Lambda
层在其中进行了一些奇怪的dtype转换。
我认为这是某种错误,我的假设是隐式转换发生在Lambda
的{{1}} (which is inherited from Layer.__call__
)内部。我无法跟踪它,但是我猜想“隐式转换”错误位于__call__
中的某个位置,但是在451行之前,实际上是在Layer.__call__
处被调用的。