tensorflow keras中model.trainable = False的预期行为和目的是什么

时间:2019-06-07 14:36:59

标签: tensorflow keras deep-learning keras-layer tf.keras

似乎在tensorflow keras中设置model.trainable=False除了打印错误的model.summary()之外什么都不做。这是重现此问题的代码:

import tensorflow as tf
import numpy as np
IMG_SHAPE = (160, 160, 3)

# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False, 
                                               weights='imagenet')
base_model.trainable = False
# for layer in base_model.layers:
#     layer.trainable=False
bc=[] #before compile
ac=[] #after compile
for layer in base_model.layers:
    bc.append(layer.trainable)
print(np.all(bc)) #True
print(base_model.summary()) ##this changes to show no trainable parameters but that  is wrong given the output to previous np.all(bc)
base_model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001), 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])
for layer in base_model.layers:
    ac.append(layer.trainable)
print(np.all(ac)) #True
print(base_model.summary()) #this changes to show no trainable parameters but that  is wrong given the output to previous np.all(ac)

鉴于此-在tensorflow keras中,model.trainable = False的预期行为和目的是什么?

1 个答案:

答案 0 :(得分:0)

https://github.com/tensorflow/tensorflow/issues/29535

我认为这个问题可能会有所帮助。

如果您正在寻找一种不更新模型权重的方法,建议您在var_list的{​​{1}}函数中使用参数minimize

出于某种原因,当从keras Tensorflow创建模型时,将所有tf.Variables都切换为True,并且由于都是Tensors,因此我们无法将其值更新为False。

我在代码中所做的是为所有预训练的模型创建范围名称,并在其上循环,添加非预训练模型的所有图层。

Optimizer

还要注意global_initializer,因为它也会覆盖您的预训练权重。您可以使用trainable_variables = [] variables_collection = tf.get_collection('learnable_variables') for layer in tf.trainable_variables(): if 'vgg_model' not in layer.name: trainable_variables.append(layer) tf.add_to_collection('learnable_variables', layer) grad = tf.train.GradientDescentOptimizer(lr) train_step = grad.minimize(tf.reduce_sum([loss]), var_list=trainable_variables) 并传递要添加权重的变量列表来解决该问题。

tf.variables_initializer

我尝试解决此问题时使用的来源 Is it possible to make a trainable variable not trainable?

TensorFlow: Using tf.global_variables_initializer() after partially loading pre-trained weights