我现在正在训练一个生成对抗网络,我有一种不好的感觉,即权重正在以一种不应该的方式更新,所以我正在尝试改变它。但是由于数组和可变性,我无法使用 np.array_equal()
函数正确检查。
我的代码管道看起来有点像:
gen_1=gen.trainable_variables
disc_1=disc.trainable_variables
disc.train_on_batch()#train discriminator
disc_2=disc.trainable_variables #to get values after weight update
但由于这些是可变数组,我认为 disc_1 和 disc_2 指向相同的值。 np.copy()
失败,因为 model.trainable_variables
返回序列而不是数组
我实际上可以做些什么来解决这个问题?