我正在尝试在keras中编写自定义损失函数,该函数基于训练过程中模型的多个预测。 我以前有一个列表/字典,例如dict = {X1:[X2,X3,X4],X2:[X1,X6,X7]}等。
鉴于输入X1,我需要在训练期间预测X2,X3,X4。
model.predict无法正常工作,出现错误 ValueError:将符号张量馈送到模型时,我们期望张量具有静态批大小。得到了张量形状:(无,36)
from keras.layers import *
from keras.models import Model
import keras.backend as K
input_tensor = Input(shape=(36,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='linear')(hidden)
def custom_loss(input_tensor, dict):
def inner(y_true, y_pred):
X2 = dict[input_tensor][0]
X3 = dict[input_tensor][1]
X4 = dict[input_tensor][2]
X2_pred = model.predict(X2)
X3_pred = model.predict(X3)
X4_pred = model.predict(X4)
return K.mean(max(X2_pred, X3_pred, X4_pred)-y_true)
return inner
custom_loss_final = custom_loss(input_tensor = input_tensor, dict = dict)
model = Model(input_tensor, out)
model.compile(loss = custom_loss_final, optimizer='adam')
model.fit(x = Train_X, y = Train_y, batch_size= 100)
基于 Anakin's 解决方案进行编辑: 我尝试了您的代码,实际上是在np.append处,我需要使用axis = 0。
现在我有: INPUT_X.shape (100,36) INPUT_Y.shape (100,3,36)
INPUT_X:我有100个训练样本,每个样本都是36 len数组。
INPUT_Y:这些是X2,X3等
实际上,我甚至不需要在代码中使用Y_true,因为我将使用模型(X2)等。
我按照您的建议将它们传递给model.fit,在损失函数中我打印了一些类型/形状:
(input_tensor)
(?,36)
(“ y_true”)
(?,?)
(pred_y)
(?,1)
我不知道为什么y_true形状是(?,?)而不是(3,36)。我不能放模型(pred_y [:,0]),因为我得到了:
ValueError:对于输入格式为[?],[36,300]的'loss_56 / dense_114_loss / model_57 / dense_113 / MatMul'(运算符:'MatMul'),形状必须为2级,但为1级。
##为了清楚起见,现在隐藏层的大小为300。
答案 0 :(得分:0)
您可以在输入中用额外的数据列填充标签,并编写自定义损失。您可以传递额外的预测信息。您输入的是像这样的numpy数组
def custom_loss(data, y_pred):
y_true = data[:, 0]
extra = data[:, 1:]
return K.mean(K.square(y_pred - y_true), axis=-1) + something with extra...
def baseline_model():
# create model
i = Input(shape=(5,))
x = Dense(5, kernel_initializer='glorot_uniform', activation='linear')(i)
o = Dense(1, kernel_initializer='normal', activation='linear')(x)
model = Model(i, o)
model.compile(loss=custom_loss, optimizer=Adam(lr=0.0005))
return model
model.fit(X, np.append(Y_true, extra, axis =1), batch_size = batch_size, epochs=90, shuffle=True, verbose=1)
该示例摘自我的代码的一部分,但希望您能理解。