我想在训练过程中动态调整两个损失函数的权重。但是，当我在代码中使用回调（Callback）时，报错提示损失权重变量未初始化。我已经用纯 Keras 测试过类似的代码，它可以正常运行，没有初始化问题。TensorFlow 风格的代码如下：
# dynamic alpha weights for two losses
class MyCallback(Callback):
    """Re-weight the two model losses at the end of every epoch.

    After each epoch the weight variable ``alpha`` is set to the share of the
    ``sumrate_loss`` term in the sum of both logged losses::

        alpha = sumrate_loss_loss / (sumrate_loss_loss + Wnorm_loss_loss)

    The model is expected to be compiled with ``loss_weights`` built from this
    same ``alpha`` object (e.g. ``alpha`` and ``1 - alpha``) so the new value
    is picked up in the following epoch.

    Args:
        alpha: backend variable holding the weight, updated via ``K.set_value``.
            It should be created through the Keras backend (``K.variable``) so
            the Keras session initializes it; a raw ``tf.Variable`` can be
            reported as an uninitialized value.
    """

    def __init__(self, alpha):
        # Run the base Callback constructor so its attributes (e.g. ``model``)
        # exist before Keras attaches the callback.
        super(MyCallback, self).__init__()
        self.alpha = alpha

    def on_epoch_end(self, epoch, logs=None):
        """Recompute ``alpha`` from this epoch's logged per-output losses.

        Args:
            epoch: index of the finished epoch (unused).
            logs: metrics dict passed in by Keras; must contain
                ``'sumrate_loss_loss'`` and ``'Wnorm_loss_loss'``.

        Raises:
            KeyError: if either loss key is missing from ``logs``.
        """
        # ``None`` default instead of a shared mutable ``{}``.
        logs = logs or {}
        sumrate = logs['sumrate_loss_loss']
        wnorm = logs['Wnorm_loss_loss']
        total = sumrate + wnorm
        if total == 0:
            # The ratio is undefined; keep the current weight rather than
            # aborting training with ZeroDivisionError.
            return
        K.set_value(self.alpha, sumrate / total)
# CNN model
input_H = Input(shape=(Nr, Nt))
input_H_expand = Lambda(lambda x: K.expand_dims(x))(input_H)
input_sigma2 = Input(shape=(Nr,))
temp = Conv2D(Nr * Nt, (1, 2), activation='relu', padding='same')(input_H_expand)
temp = Conv2D(Nr * Nt, (1, 2), activation='relu', padding='same')(temp)
temp = Flatten()(temp)
output_W = Dense(Nt * Nr, activation='linear')(concatenate([temp, input_sigma2], axis=-1))
output_W = Lambda(lambda x: K.reshape(x, [-1, Nt, Nr]))(output_W)
input_I_DC = Input(shape=(1, ))
# These Lambda layers output the loss terms themselves; the compiled losses
# below simply pass their values through.
sumrate_loss = Lambda(sumrate, name='sumrate_loss')([input_H, output_W, input_sigma2])
W_norm_reg = Lambda(norms_reg, name='Wnorm_loss')([output_W, input_I_DC])
model = Model(inputs=[input_H, input_sigma2, input_I_DC], outputs=[sumrate_loss, W_norm_reg])

# Loss weight shared with MyCallback. Create it through the Keras backend
# rather than with tf.Variable: a Keras backend variable is initialized in the
# session Keras uses, whereas the raw tf.Variable produced
# "FailedPreconditionError: Attempting to use uninitialized value ...".
alpha = K.variable(0.5, name='alpha')
model.compile(optimizer='adam',
              # Each output already is a loss value; y_true (dummy) is ignored.
              loss={
                  'sumrate_loss': lambda y_true, y_pred: y_pred,
                  'Wnorm_loss': lambda y_true, y_pred: y_pred},
              # ``1 - alpha`` is a tensor derived from ``alpha``, so updating
              # ``alpha`` in MyCallback changes both weights together.
              loss_weights={
                  'sumrate_loss': alpha,
                  'Wnorm_loss': 1 - alpha},
              )

# fit model
# NOTE(review): reduce_lr and checkpoint are created but never passed to
# fit(); add them to ``callbacks`` below if they are meant to be active.
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                              factor=0.2,
                              patience=20,
                              min_lr=0.0005)
checkpoint = ModelCheckpoint('./temp_trained_cnn_maxsumrate_p1_basic.h5', monitor='val_loss',
                             verbose=0, save_best_only=True, mode='min', save_weights_only=True)

# Noise-variance slice used for normalizing both the channel and noise inputs.
sigma2_slice = sigma2n_sample[:, :, 10]
input_fit = [H_sample / np.min(np.sqrt(sigma2_slice)),
             sigma2_slice / np.min(sigma2_slice),
             I_DC_sample[:, 10]]
# Placeholder targets: the custom losses never read y_true.
random_y = [np.random.random(H_sample.shape[0]), np.random.random(H_sample.shape[0])]
history = model.fit(input_fit, random_y,
                    batch_size=16,
                    epochs=10,
                    verbose=1,
                    validation_split=0.2,
                    callbacks=[MyCallback(alpha)])
错误信息表明 MyCallback 中使用的 alpha 尚未初始化。详细的错误如下：
---> 14     callbacks=[MyCallback(alpha)])
FailedPreconditionError: Attempting to use uninitialized value Variable_7 [[{{node Variable_7/read}} = Identity[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]]
希望有人能告诉我，应该如何初始化 loss_weights 中使用的 alpha 变量。