Model.trainable = False与Model.compile()

时间:2019-06-25 23:23:48

标签: python keras

这些说法正确吗?

    除非进行编译,否则
  • Model.trainable = False本身对任何编译的东西绝对没有作用。
  • 如果我在已编译(ModelA)的ModelA.compile(...)中使用两层,请创建一个跳过模型ModelB=Model(intermediate_layer1, intermediate_layer2)并设置ModelB.trainable=FalseModelB.compile(...),什么也不要将更改为ModelA;假设尚未触及可训练对象,如果仅训练ModelA(ModelA),ModelA.fit(...)中的所有内容的权重都会更新
  • 这仅与重量更新有关,因此重量将被保存/加载而不会出现问题(即使它是错误重量)。

当我尝试训练GAN时一切都开始了,当训练发电机并得到以下警告时冻结了鉴别器:

 UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?

我调查了一下,发现人们也对此进行了调查:

https://github.com/keras-team/keras/issues/8585

这是根据该Issue线程改编的示例:

# making discriminator
d_input = Input(shape=(2,))
d_output = Activation('softmax')(Dense(2)(d_input))
discriminator = Model(inputs=d_input, outputs=d_output)
discriminator.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

# making generator
g_input = Input(shape=(2,))
g_output = Activation('relu')(Dense(2)(g_input))
generator = Model(inputs=g_input, outputs=g_output)

# making gan(generator -> discriminator)
discriminator.trainable = False # CHECK THIS OUT!
gan = Model(inputs=g_input, outputs=discriminator(g_output))
gan.compile(loss='categorical_crossentropy', optimizer='adam')

# training
BATCH_SIZE = 3
some_input_data = np.array([[1,2],[3,4],[5,6]])
some_target_data = np.array([[1,1],[2,2],[3,3]])
# update discriminator
generated = generator.predict(some_input_data, verbose=0)
X = np.concatenate((some_target_data, generated), axis=0)
y = [[0,1]]*BATCH_SIZE + [[1,0]]*BATCH_SIZE
d_metrics = discriminator.train_on_batch(X, y)
# update generator
g_metrics = gan.train_on_batch(some_input_data, [[0,1]]*BATCH_SIZE)
# loop these operations for batches...

当有人说这是一个错误的警告而有人说重量可能被弄乱时,我感到困惑。

然后我读了一个问题:shouldn't model.trainable=False freeze weights under the model?

这篇文章很好地解释了“可训练”的实际作用。我想知道我的理解是否正确,并确保我的GAN训练正确。

0 个答案:

没有答案