如何在Tensorflow2 Keras模型损失函数中添加可训练的参数

时间:2020-09-06 20:47:47

标签: python tensorflow keras

我正在尝试使用Keras(Tensorflow2)训练图像降噪器网络。对于损失函数,我想使用类似( a1 * L1_loss + a2 * L2_loss)的方式,其中 a1 a2 < / strong>是可训练的,这意味着在我给他们初始值之后,他们可以在每次训练迭代中得到更新。但是我在这里停留了一段时间,并且确实知道该如何实施。

这是一些示例代码,

m

我的损失函数定义为

new Cell()

然后我使用fit()函数传递包含训练数据的tf.data.Dataset进行训练。

尽管我可以通过这种方式添加两个权重参数,但是这些权重是不可训练的,并且它们不会随着训练而改变。如果有人对这个问题有一些想法,我真的希望能得到一些提示或例子。任何帮助表示赞赏!

0 个答案:

没有答案