检查模型权重是否在 tensorflow 中发生了变化?

时间:2021-05-13 09:04:17

标签: python list numpy tensorflow

我现在正在训练一个生成对抗网络,我有一种不好的感觉,即权重正在以一种不应该的方式更新,所以我正在尝试改变它。但是由于数组和可变性,我无法使用 np.array_equal() 函数正确检查。

我的代码管道看起来有点像:

gen_1=gen.trainable_variables
disc_1=disc.trainable_variables
disc.train_on_batch()#train discriminator
disc_2=disc.trainable_variables #to get values after weight update

但由于这些是可变数组,我认为 disc_1 和 disc_2 指向相同的值。 np.copy() 失败,因为 model.trainable_variables 返回序列而不是数组

我实际上可以做些什么来解决这个问题?

0 个答案:

没有答案