How do I freeze a submodel of one model without affecting other models?

Date: 2020-04-17 15:43:59

Tags: tensorflow keras

I'm trying to build a GAN-like model, but I can't figure out how to correctly set only one model's trainable flag to False. It seems that every model that uses the submodel is affected.

Code:

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

print(tf.__version__)

def build_submodel():
  inp = tf.keras.Input(shape=(3,))
  x = Dense(5)(inp)
  model = Model(inputs=inp, outputs=x)
  return model

def build_model_A():
  inp = tf.keras.Input(shape=(3,))
  x = submodel(inp)
  x = Dense(7)(x)
  model = Model(inputs=inp, outputs=x)
  return model

def build_model_B():
  inp = tf.keras.Input(shape=(11,))
  x = Dense(3)(inp)
  x = submodel(x)
  model = Model(inputs=inp, outputs=x)
  return model

submodel = build_submodel()
model_A = build_model_A()
model_A.compile("adam", "mse")
model_A.summary()
submodel.trainable = False
# same result when freezing the individual layers instead:
# for layer in submodel.layers:
#   layer.trainable = False
model_B = build_model_B()
model_B.compile("adam", "mse")
model_B.summary()

model_A.summary()

Output:

Model: "model_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        [(None, 3)]               0         
_________________________________________________________________
model_9 (Model)              (None, 5)                 20        
_________________________________________________________________
dense_10 (Dense)             (None, 7)                 42        
=================================================================
Total params: 62
Trainable params: 62
Non-trainable params: 0
_________________________________________________________________
Model: "model_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_12 (InputLayer)        [(None, 11)]              0         
_________________________________________________________________
dense_11 (Dense)             (None, 3)                 36        
_________________________________________________________________
model_9 (Model)              (None, 5)                 20        
=================================================================
Total params: 56
Trainable params: 36
Non-trainable params: 20
_________________________________________________________________
Model: "model_10"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        [(None, 3)]               0         
_________________________________________________________________
model_9 (Model)              (None, 5)                 20        
_________________________________________________________________
dense_10 (Dense)             (None, 7)                 42        
=================================================================
Total params: 62
Trainable params: 42
Non-trainable params: 20
_________________________________________________________________

At first, model_A has no non-trainable weights. But after model_B is built, model_A suddenly reports some non-trainable weights.

Also, the summary doesn't show which layers are non-trainable, only the total number of non-trainable parameters. Is there a better way to check which layers in a model are frozen?
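A minimal sketch of why this happens (assuming TF 2.x Keras; names like `inp_a` and `inp_b` are illustrative): `trainable` is an attribute of the shared submodel object itself, so flipping it changes what every parent model reports. Per the Keras transfer-learning guide, however, `compile()` is meant to snapshot the trainable state, so a model compiled before freezing keeps training those weights until it is recompiled.

```python
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

# One submodel instance shared by two parent models (illustrative names).
sub_inp = Input(shape=(3,))
submodel = Model(inputs=sub_inp, outputs=Dense(5)(sub_inp))

inp_a = Input(shape=(3,))
model_A = Model(inputs=inp_a, outputs=Dense(7)(submodel(inp_a)))
model_A.compile("adam", "mse")  # compiled while submodel is still trainable

# `trainable` lives on the shared submodel object, so this flips the flag
# everywhere the submodel is used, which is why model_A's summary changes.
submodel.trainable = False

inp_b = Input(shape=(11,))
model_B = Model(inputs=inp_b, outputs=submodel(Dense(3)(inp_b)))
model_B.compile("adam", "mse")

# Both parents now report the submodel's kernel/bias as non-trainable.
print(len(model_A.trainable_weights), len(model_A.non_trainable_weights))
print(len(model_B.trainable_weights), len(model_B.non_trainable_weights))
```

This is the usual GAN recipe: compile the model that should update the shared part first, then freeze it, then build and compile the second model; each compiled model keeps the trainable snapshot it was compiled with.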

1 answer:

Answer 0: (score: 0)

You can use this function to display which layers are trainable:

from tensorflow.keras import backend as K

def print_params(model):

    def count_params(weights):
        """Count the total number of scalars composing the weights.
        # Arguments
            weights: An iterable containing the weights on which to compute params
        # Returns
            The total number of scalars composing the weights
        """
        weight_ids = set()
        total = 0
        for w in weights:
            if id(w) not in weight_ids:
                weight_ids.add(id(w))
                total += int(K.count_params(w))
        return total

    trainable_count = count_params(model.trainable_weights)
    non_trainable_count = count_params(model.non_trainable_weights)

    print('id\ttrainable : layer name')
    print('-------------------------------')
    for i, layer in enumerate(model.layers):
        print(i, '\t', layer.trainable, '\t  :', layer.name)
    print('-------------------------------')

    print('Total params: {:,}'.format(trainable_count + non_trainable_count))
    print('Trainable params: {:,}'.format(trainable_count))
    print('Non-trainable params: {:,}'.format(non_trainable_count))

It will produce output like this:

id  trainable : layer name
-------------------------------
0    False    : input_1
1    False    : block1_conv1
2    False    : block1_conv2
3    False    : block1_pool
4    False    : block2_conv1
5    False    : block2_conv2
6    False    : block2_pool
7    False    : block3_conv1
8    False    : block3_conv2
9    False    : block3_conv3
10   False    : block3_pool
11   False    : block4_conv1
12   False    : block4_conv2
13   False    : block4_conv3
14   False    : block4_pool
15   False    : block5_conv1
16   False    : block5_conv2
17   False    : block5_conv3
18   False    : block5_pool
19   True     : global_average_pooling2d
20   True     : dense
21   True     : dense_1
22   True     : dense_2
-------------------------------
Total params: 15,245,130
Trainable params: 530,442
Non-trainable params: 14,714,688
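For the question's setup, the same per-layer report can be run on a model that wraps a frozen submodel. A self-contained sketch (hypothetical `inner`/`outer` names; the loop is the core of `print_params` above):

```python
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

# Stand-in for the question's setup: a frozen inner model inside a parent.
inner_inp = Input(shape=(3,))
inner = Model(inputs=inner_inp, outputs=Dense(5)(inner_inp))
inner.trainable = False

outer_inp = Input(shape=(3,))
outer = Model(inputs=outer_inp, outputs=Dense(7)(inner(outer_inp)))

# The per-layer trainable report, inline:
for i, layer in enumerate(outer.layers):
    print(i, '\t', layer.trainable, '\t  :', layer.name)
```

Note that a nested Model appears as a single row here; to list its internal layers as well, recurse into `layer.layers` whenever `isinstance(layer, Model)`.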