用常量初始化输入层时遇到麻烦

时间:2019-06-04 14:37:38

标签: python tensorflow keras

我的模型中有三个Input层,并且将“ input3”设置为常量值。然后,将“ input3”输入到Embedding层,获取结果“ lookup_table”,然后执行其他一些操作。

但是当我使用model.summary()来观察我的模型和训练参数时,我发现Input3层和Embedding层没有添加到模型中,并且我认为Embedding层的参数将不会被训练

我真的为此感到困扰,任何帮助将不胜感激!

notifyDatasetChanged()
The code 

import numpy as np
from keras.models import Model
from keras.layers import*
import keras.backend as K


np_constant = np.array([[1,2,3],
                        [4,5,6],
                        [7,8,9]])

def NN():
    input1 = Input(batch_shape=(None,1),name='input1',dtype='int32')
    input2 = Input(batch_shape=(None,1),name='input2',dtype='int32')
    # constant_tensor = K.constant(np_constant)
    input3 = Input(tensor=K.constant(np_constant),batch_shape=(3,3),dtype='int32',name='constant_input_3')
    embedding = Embedding(input_dim=10,output_dim=5,input_length=3)
    lookup_table = embedding(input3)
    lookup_table = Lambda(lambda x: K.reshape(x, (-1,15)))(lookup_table)

    output1 = Lambda(lambda x: K.gather(lookup_table, K.cast(x, dtype='int32')))(input1)
    output2 = Lambda(lambda x: K.gather(lookup_table, K.cast(x, dtype='int32')))(input2)

    # Merge branches
    output = Concatenate(axis=1)([output1, output2])
    # Process merged branch
    output = Dense(units=2
                   , activation='softmax'
                   )(output)

    model = Model([input1, input2, input3], outputs=output)
    return model

model = NN()
model.summary()
in_1 = np.array([1,2,1])
in_2 = np.array([1,0,1])
model.compile()  # just for example
model.fit([in_1,in_2])

我必须在model.fit()函数中提供数据,并且input3始终是常数,并且input3的形状不同于input1和input2,所以我以这种方式使用它。但是我不知道为什么没有将Input3层和Embedding层添加到模型中。

1 个答案:

答案 0 :(得分:0)

我修改了原始代码,在模型外部定义了一个自定义函数,并按照Anakin的建议将张量列表传递到Lambda层中。这是修改后的代码。

import numpy as np
from keras.models import Model
from keras.layers import*
import keras.backend as K


np_constant = np.array([[1,2,3],
                        [4,5,6],
                        [7,8,9]])

def look_up(arg):
    in1 = arg[0]
    in2 = arg[1]
    lookup_table = arg[2]

    in1 = Lambda(lambda x: K.reshape(x, (-1, )))(in1)
    in2 = Lambda(lambda x: K.reshape(x, (-1, )))(in2)

    output1 = Lambda(lambda x: K.gather(lookup_table, K.cast(x, dtype='int32')))(in1)
    output2 = Lambda(lambda x: K.gather(lookup_table, K.cast(x, dtype='int32')))(in2)
    return [output1,output2]

def NN():
    input1 = Input(batch_shape=(None,1),name='input1',dtype='int32')
    input2 = Input(batch_shape=(None,1),name='input2',dtype='int32')
    # constant_tensor = K.constant(np_constant)
    input3 = Input(tensor=K.constant(np_constant),batch_shape=(3,3),dtype='int32',name='constant_input_3')
    lookup_table = Embedding(input_dim=10,output_dim=5,input_length=3)(input3)
    lookup_table = Lambda(lambda x: K.reshape(x, (-1, 15)))(lookup_table)


    output1 = Lambda(look_up)([input1,input2,lookup_table])[0]
    output2 = Lambda(look_up)([input1,input2,lookup_table])[1]
    # Merge branches
    output = Concatenate(axis=1)([output1, output2])
    # Process merged branch
    output = Dense(units=2
                   , activation='softmax'
                   )(output)

    model = Model([input1, input2, input3], outputs=output)
    return model

model = NN()
model.summary()
input_1 = np.array([1,2,1])
input_2 = np.array([1,0,1])
model.compile()  # just for example
model.fit([input_1,input_2])

通过这种方式,可以将Embedding添加到模型中。而且input3是一个常数张量,我们不需要在model.fit()函数中提供它。

The model summary

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
constant_input_3 (InputLayer)   (3, 3)               0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (3, 3, 5)            50          constant_input_3[0][0]           
__________________________________________________________________________________________________
input1 (InputLayer)             (None, 1)            0                                            
__________________________________________________________________________________________________
input2 (InputLayer)             (None, 1)            0                                            
__________________________________________________________________________________________________
lambda_1 (Lambda)               (3, 15)              0           embedding_1[0][0]                
__________________________________________________________________________________________________
lambda_2 (Lambda)               [(None, 15), (None,  0           input1[0][0]                     
                                                                 input2[0][0]                     
                                                                 lambda_1[0][0]                   
__________________________________________________________________________________________________
lambda_11 (Lambda)              [(None, 15), (None,  0           input1[0][0]                     
                                                                 input2[0][0]                     
                                                                 lambda_1[0][0]                   
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 30)           0           lambda_2[0][0]                   
                                                                 lambda_11[0][1]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 2)            62          concatenate_1[0][0]              
==================================================================================================
Total params: 112
Trainable params: 112
Non-trainable params: 0
__________________________________________________________________________________________________