Are these statements correct?
- `Model.trainable = False` by itself has absolutely no effect on anything already compiled into `ModelA` by `ModelA.compile(...)`
- if you want to use two layers of `ModelA` inside another model, create a skip model `ModelB = Model(intermediate_layer1, intermediate_layer2)`, set `ModelB.trainable = False`, call `ModelB.compile(...)`, and change nothing in `ModelA`
- assuming the `trainable` flags have not been touched, if you train only `ModelA` (`ModelA.fit(...)`), the weights of everything inside `ModelA` will be updated

It all started when I tried to train a GAN: I froze the discriminator while training the generator and got the following warning:
UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ?
I looked into it and found that other people have investigated it as well:
https://github.com/keras-team/keras/issues/8585
Here is an example adapted from that issue thread:
import numpy as np
from keras.layers import Input, Dense, Activation
from keras.models import Model

# build the discriminator
d_input = Input(shape=(2,))
d_output = Activation('softmax')(Dense(2)(d_input))
discriminator = Model(inputs=d_input, outputs=d_output)
discriminator.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

# build the generator
g_input = Input(shape=(2,))
g_output = Activation('relu')(Dense(2)(g_input))
generator = Model(inputs=g_input, outputs=g_output)

# build the gan (generator -> discriminator)
discriminator.trainable = False  # CHECK THIS OUT!
gan = Model(inputs=g_input, outputs=discriminator(g_output))
gan.compile(loss='categorical_crossentropy', optimizer='adam')

# training
BATCH_SIZE = 3
some_input_data = np.array([[1, 2], [3, 4], [5, 6]])
some_target_data = np.array([[1, 1], [2, 2], [3, 3]])

# update the discriminator (real -> [0, 1], generated -> [1, 0])
generated = generator.predict(some_input_data, verbose=0)
X = np.concatenate((some_target_data, generated), axis=0)
y = np.array([[0, 1]] * BATCH_SIZE + [[1, 0]] * BATCH_SIZE)
d_metrics = discriminator.train_on_batch(X, y)

# update the generator (push the discriminator toward labeling generated data as real)
g_metrics = gan.train_on_batch(some_input_data, np.array([[0, 1]] * BATCH_SIZE))

# loop these operations over batches...
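The way I currently picture the `# CHECK THIS OUT!` line is with the following framework-free toy sketch (the `ToyModel` class is purely hypothetical, not part of Keras): if `compile()` snapshots the `trainable` flag, then flipping the flag after `discriminator.compile(...)` leaves `discriminator.train_on_batch` still updating weights, while `gan.compile(...)`, called after the flip, sees the discriminator as frozen.

```python
class ToyModel:
    """Hypothetical stand-in illustrating 'trainable is collected at compile time'."""

    def __init__(self, weights):
        self.trainable = True
        self.weights = dict(weights)
        self._compiled_trainable = None  # snapshot taken at compile()

    def compile(self):
        # Snapshot the current flag, the way Keras collects trainable
        # weights when the model is compiled.
        self._compiled_trainable = self.trainable

    def train_step(self, delta):
        # Only the snapshot decides whether weights update; the live
        # flag is ignored until the next compile().
        if self._compiled_trainable:
            for name in self.weights:
                self.weights[name] += delta


m = ToyModel({"w": 0.0})
m.compile()              # snapshot: trainable
m.trainable = False      # flipped AFTER compile -> no effect yet
m.train_step(1.0)
print(m.weights["w"])    # -> 1.0 (still updated)

m.compile()              # re-compile picks up trainable=False
m.train_step(1.0)
print(m.weights["w"])    # -> 1.0 (frozen now)
```

If this mental model is right, the GAN example above relies on compiling `discriminator` before the flag flip and `gan` after it.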
I got confused because some people say this is a spurious warning, while others say the weights may actually get messed up.
Then I read this question: shouldn't model.trainable=False freeze weights under the model? That post explains well what `trainable` actually does. I would like to know whether my understanding above is correct, and to make sure my GAN is being trained correctly.