How can I freeze TensorFlow variables inside tf.keras in eager execution mode?

Time: 2019-06-06 15:02:35

Tags: tensorflow lstm recurrent-neural-network tf.keras gated-recurrent-unit

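I am trying to fine-tune the input weights of a recurrent cell without letting backpropagation affect the previous states (a kind of truncated backpropagation with n = 1). I am using tf.keras with eager execution in TensorFlow.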
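For reference, one way to picture truncated backpropagation with n = 1 is to cut the gradient at the carried state on every time step. The sketch below is only an illustration and not the setup from the question: it assumes the TensorFlow 2.x API (where eager execution is the default), drives a `tf.keras.layers.GRUCell` from a hand-written loop instead of a stateful layer, and uses made-up shapes.

```python
import tensorflow as tf

# Hypothetical sizes, chosen only for illustration.
batch, steps, features, units = 2, 3, 8, 4

cell = tf.keras.layers.GRUCell(units)
inputs = tf.random.normal((batch, steps, features))
state = tf.zeros((batch, units))

with tf.GradientTape() as tape:
    tape.watch(inputs)  # inputs is not a variable, so watch it explicitly
    for t in range(steps):
        # Cut the graph here: nothing before step t receives a gradient.
        state = tf.stop_gradient(state)
        output, new_states = cell(inputs[:, t, :], [state])
        state = new_states[0]
    loss = tf.reduce_sum(output)

grad = tape.gradient(loss, inputs)
early_grad_norm = float(tf.norm(grad[:, :-1, :]))  # steps before the last one
last_grad_norm = float(tf.norm(grad[:, -1, :]))    # the last step only
print(early_grad_norm, last_grad_norm)
```

With the state detached at every step, only the input of the final step should receive a gradient. Note that the recurrent kernel itself still gets a gradient from that final step, which is why freezing it is a separate problem.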

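I cannot find a way to freeze a specific part of a GRU cell, in particular the recurrent kernel. The recurrent kernel appears to be a TensorFlow variable, and I cannot find a way to set its trainable attribute to False.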
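One possible way to approach this in eager mode, sketched here under assumptions rather than offered as a confirmed answer, is to leave the `trainable` flag alone and simply keep the recurrent kernel out of the variable list passed to `tf.GradientTape.gradient` and the optimizer. The sketch assumes the TensorFlow 2.x API, a plain `tf.keras.layers.GRU` (not `CuDNNGRU`) whose cell exposes `recurrent_kernel`, and a tiny stand-in model with made-up sizes and data.

```python
import tensorflow as tf

# Tiny stand-in model; the sizes are arbitrary illustration values.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(16, 8),
    tf.keras.layers.GRU(4, return_sequences=True),
    tf.keras.layers.Dense(16),
])

tokens = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
targets = tf.constant([[2, 3, 4, 5], [6, 7, 8, 9]])
model(tokens)  # run once so that all variables are created

# Assumption: the GRU layer's cell holds the kernel and recurrent_kernel.
gru_cell = model.layers[1].cell
frozen = gru_cell.recurrent_kernel

# Compare by identity; `!=` on variables would compare element-wise.
train_vars = [v for v in model.trainable_variables if v is not frozen]

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

frozen_before = frozen.numpy().copy()
kernel_before = gru_cell.kernel.numpy().copy()

with tf.GradientTape() as tape:
    loss = loss_fn(targets, model(tokens))

# Gradients are computed and applied only for the variables we kept.
grads = tape.gradient(loss, train_vars)
optimizer.apply_gradients(zip(grads, train_vars))

frozen_unchanged = bool((frozen_before == frozen.numpy()).all())
kernel_changed = bool((kernel_before != gru_cell.kernel.numpy()).any())
print(frozen_unchanged, kernel_changed)
```

This only works with a custom training loop; `model.fit` would still update every trainable variable of the layer.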

My code is based on this tutorial on text_generation (Google Colab version, where you can modify the build_model function and test it).


def build_model(vocab_size, embedding_dim, rnn_units, batch_size,
                freeze_embedding_layer=False, freeze_recurrent_kernel=False):
    # `rnn` is the layer class chosen by the calling code (CuDNNGRU or GRU).
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(vocab_size, embedding_dim,
                                  batch_input_shape=[batch_size, None]),
        rnn(rnn_units, return_sequences=True,
            recurrent_initializer='glorot_uniform', stateful=True),
        tf.keras.layers.Dense(vocab_size)
    ])

    if freeze_embedding_layer:
        print("embedding type:", model.layers[0])
        model.layers[0].trainable = False

    if freeze_recurrent_kernel:
        print("rnn type:", type(model.layers[1]))
        print("rnn recurrent kernel type:", type(model.layers[1].recurrent_kernel))
        # This assignment raises the AttributeError shown in the traceback.
        model.layers[1].recurrent_kernel.trainable = False

    return model

When I call this function, for example:


# Length of the vocabulary in chars
vocab_size = len(vocab)

# The embedding dimension
embedding_dim = 256

# Number of RNN units
rnn_units = 1024

if tf.test.is_gpu_available():
    rnn = tf.keras.layers.CuDNNGRU
else:
    import functools
    rnn = functools.partial(
        tf.keras.layers.GRU, recurrent_activation='sigmoid')

model = build_model(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE,
    freeze_embedding_layer=True,
    freeze_recurrent_kernel=True)

I get:

embedding type: <tensorflow.python.keras.layers.embeddings.Embedding object at 0x7f955a198d68>
rnn type: <class 'tensorflow.python.keras.layers.cudnn_recurrent.CuDNNGRU'>
rnn recurrent kernel type: <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-19-1677e05c2afc> in <module>()
      3   embedding_dim=embedding_dim,
      4   rnn_units=rnn_units,
----> 5   batch_size=BATCH_SIZE, freeze_embedding_layer=True, freeze_recurrent_kernel=True)

<ipython-input-18-62788170b303> in build_model(vocab_size, embedding_dim, rnn_units, batch_size, freeze_embedding_layer, freeze_recurrent_kernel)
     15       print("rnn type:",type(model.layers[1]))
     16       print("rnn recurrent kernel type:", type(model.layers[1].recurrent_kernel))
---> 17       model.layers[1].recurrent_kernel.trainable = False
     18 
     19 

AttributeError: can't set attribute
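The message is the one Python raises when code assigns to a property that has no setter, which would be consistent with `trainable` being a read-only property on the variable, fixed when the variable is created. The class below is a hypothetical stand-in written in plain Python to reproduce that behaviour; it is not TensorFlow's actual implementation.

```python
class ReadOnlyVariable:
    """Hypothetical stand-in for a variable whose trainable flag is fixed at creation."""

    def __init__(self, trainable=True):
        self._trainable = trainable

    @property
    def trainable(self):
        # Getter only: no setter is defined, so assignment fails.
        return self._trainable


var = ReadOnlyVariable()
try:
    var.trainable = False
    assignment_failed = False
except AttributeError as err:
    assignment_failed = True
    print("AttributeError:", err)

print("trainable is still:", var.trainable)
```

The exact wording of the error depends on the Python version, but the assignment fails in the same way and the flag keeps its original value.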

0 answers:

No answers yet.