当我将保存的模型加载到Keras中时,输入会丢失,而自定义丢失会占用多个输入

时间:2019-10-08 13:30:55

标签: python tensorflow machine-learning keras

tf。版本 '1.12.0'

我有一个自定义损失函数,需要多个输入。除非我尝试保存和加载模型,否则它工作正常。这是一个简单的示例,显示了输入丢失的方式。请看看。

x = tf.keras.Input(shape=(5,), name='input')
y_true = tf.keras.Input(shape=(5,), name='y_true' )
y_pred = tf.keras.layers.Dense(5)(x)
other_data = tf.keras.Input(shape=(5,), name='other_data' )
model = tf.keras.Model(inputs=[x, y_true, other_data],  outputs=y_pred)

def custom_loss(y_true, y_pred):
    return tf.reduce_sum(tf.pow(y_true -y_pred,2)) + tf.reduce_sum(tf.multiply(y_pred,other_data))

model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.01, beta_2=0.99), loss=custom_loss)

data = np.random.rand(5,5)
model.fit([data, data, data], data)

model.save('tmp.h5')
print(model.input_names)
model1 = tf.keras.models.load_model('tmp.h5', custom_objects={'custom_loss':custom_loss})
print(model1.input_names)

model1.fit([data, data, data], data)

  

Epoch 1/1 5/5 [=============================]-0s 42ms / step-损失:   9.4392

     

['input','y_true','other_data'] <--------------这很好

     

['input'] <-----------这里发生了什么?

     

回溯(最近通话最近一次):

     

文件“”,第21行,在       model1.fit([数据,数据,数据],数据)

     

文件   “ C:\ src \ Anaconda3 \ envs \ deepema \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py”,   1536行,适合       validate_split = validation_split)

     

文件   “ C:\ src \ Anaconda3 \ envs \ deepema \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py”,   _standardize_user_data中的第992行       class_weight,batch_size)

     

文件   “ C:\ src \ Anaconda3 \ envs \ deepema \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training.py”,   第1117行,在_standardize_weights中       exception_prefix ='input')

     

文件   “ C:\ src \ Anaconda3 \ envs \ deepema \ lib \ site-packages \ tensorflow \ python \ keras \ engine \ training_utils.py”,   第293行,位于standardize_input_data中       str(len(data))+'arrays:'+ str(data)[:200] +'...')

     

ValueError:检查模型输入时出错:Numpy数组的列表   您传递给模型的信息不是模型期望的大小。   预计将看到1个数组,但得到以下3个列表   数组:[array([[0.12768201,0.06106967,0.99779087,0.50767692,   0.21839594],          [0.82444334、0.1367274、0.14495117、0.32396153、0.24457874],          [0.29870316、0.40644681、0.69308081、0.30091417、0.776 ...

0 个答案:

没有答案