In TensorFlow 2.3, how can I train multiple models with a single total loss?

Time: 2020-09-23 09:38:53

Tags: tensorflow tensorflow2.x

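I am trying to train several models mutually. There are two or more models, and each has its own cross-entropy loss. In addition, there is a KL divergence loss between the models that compares their output distributions. I want to train all of the models at once by adding these losses together into a single total loss.

My code is as follows: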

modelA = resnet18(name='modelA')
modelB = resnet18(name='modelB')


cls_criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
kld_criterion = tf.keras.losses.KLDivergence(reduction=tf.keras.losses.Reduction.SUM)

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:

        outputA = modelA(images, training=True)
        outputB = modelB(images, training=True)

        lossA = cls_criterion(labels, outputA) + kld_criterion(tf.nn.softmax(outputB, axis=1), tf.nn.softmax(outputA, axis=1))
        lossB = cls_criterion(labels, outputB) + kld_criterion(tf.nn.softmax(outputA, axis=1), tf.nn.softmax(outputB, axis=1))
    loss = lossA + lossB
    
    gradients = tape.gradient(loss, [ v for v in model.trainable_variables for model in [modelA, modelB] ])
    optimizer.apply_gradients(zip(gradients, [ v for v in model.trainable_variables for model in [modelA, modelB] ]))
      

Then I get the following error:

ValueError: in user code:

    <ipython-input-24-6c928868de5d>:64 train_step  *
        optimizer.apply_gradients(zip(gradients, [ v for v in model.trainable_variables for mode in model_list ]))
    /home/anaconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:513 apply_gradients  **
        grads_and_vars = _filter_grads(grads_and_vars)
    /home/anaconda3/envs/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:1271 _filter_grads
        ([v.name for _, v in grads_and_vars],))

    ValueError: No gradients provided for any variable: ['Conv1/kernel:0', 'Conv1/kernel:0', 'Conv1/kernel:0', 
'bn_Conv1/gamma:0', 'bn_Conv1/gamma:0', 'bn_Conv1/gamma:0', 
'bn_Conv1/beta:0', 'bn_Conv1/beta:0', 'bn_Conv1/beta:0', 
'expanded_conv_depthwise/depthwise_kernel:0', 
'expanded_conv_depthwise/depthwise_kernel:0', 
'expanded_conv_depthwise/depthwise_kernel:0', 
'expanded_conv_depthwise_BN/gamma:0', 
'expanded_conv_depthwise_BN/gamma:0', 
'expanded_conv_depthwise_BN/gamma:0', 
'expanded_conv_depthwise_BN/beta:0', 
'expanded_conv_depthwise_BN/beta:0', ....
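
Judging only from the code shown, two things in train_step would each be enough to produce this error. First, the line loss = lossA + lossB sits outside the tf.GradientTape block, so the addition is not recorded and tape.gradient returns None for every variable. Second, the comprehension [ v for v in model.trainable_variables for model in [...] ] has its loops in the wrong order: Python evaluates model.trainable_variables before the inner loop binds model, so it reads some other global named model and repeats each of its variables once per list entry. That would be consistent with the traceback, where each name appears three times and the names resemble Keras MobileNetV2 layers rather than a ResNet-18. The following is a minimal sketch of a corrected step, not a verified fix for the original notebook. make_model is a hypothetical stand-in for resnet18; the SGD optimizer and the dummy batch are assumptions, since the question shows neither; reduction="sum" is the string form of tf.keras.losses.Reduction.SUM.

```python
import tensorflow as tf


def make_model(name, num_classes=10):
    # Hypothetical stand-in for the question's resnet18(); any Keras model
    # that returns logits of shape (batch, num_classes) should work here.
    return tf.keras.Sequential(
        [
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(32, activation="relu"),
            tf.keras.layers.Dense(num_classes),
        ],
        name=name,
    )


modelA = make_model("modelA")
modelB = make_model("modelB")
models = [modelA, modelB]

cls_criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
kld_criterion = tf.keras.losses.KLDivergence(reduction="sum")
# Assumed optimizer: the question does not show how `optimizer` is created.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)


@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        outputA = modelA(images, training=True)
        outputB = modelB(images, training=True)
        probA = tf.nn.softmax(outputA, axis=1)
        probB = tf.nn.softmax(outputB, axis=1)

        # Same loss terms and argument order as in the question.
        lossA = cls_criterion(labels, outputA) + kld_criterion(probB, probA)
        lossB = cls_criterion(labels, outputB) + kld_criterion(probA, probB)
        # The sum is taken inside the tape so that it is recorded.
        loss = lossA + lossB

    # Outer loop (models) first, inner loop (variables) second.
    variables = [v for m in models for v in m.trainable_variables]
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss


# Dummy batch, only to show the call; shapes and labels are arbitrary.
images = tf.random.normal([4, 8, 8, 3])
labels = tf.constant([0, 1, 2, 3])
loss = train_step(images, labels)
```

Building the variable list once and reusing it for both tape.gradient and apply_gradients keeps the gradients and variables aligned, which is easy to get wrong when the comprehension is written out twice.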

0 answers:

No answers yet