TensorFlow检查点文件中变量的命名逻辑

时间:2020-06-03 04:01:30

标签: python tensorflow tensorflow2.0

我已经定义了一个简单的模型,如下所示:

def create_model():
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Embedding(vocab_size, embedding_size, name='my_embedding'),
            tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, name='my_lstm'), name='my_bidir'),
            tf.keras.layers.Dense(64, activation="relu", name='my_hidden'),
            tf.keras.layers.Dense(1, name='my_final'),
        ]
    )
    model.compile(
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.Adam(1e-4),
        metrics=["accuracy"],
    )
    return model

训练后我得到了一些ckpt文件。当我检查这些文件时,得到的变量名称如下:

variables in ckpt files

但是当我使用model.trainable_weights来获取变量时,我得到了不同的结果:

enter image description here

这让我感到困惑。

  1. 为什么它们不同?是什么导致了这种差异?
  2. TF2中的命名逻辑是什么?

顺便说一句,我使用的是tensorflow 2.2.0。

0 个答案:

没有答案