Cannot correctly restore tf.keras.layers.BatchNormalization layers from a checkpoint

Date: 2019-05-14 11:15:12

Tags: restore batch-normalization checkpoint tf.keras eager-execution

I cannot restore the batch normalization layers of a trained model from a checkpoint at test time.

When I inspect the batch normalization layers' parameters after restoring, they do not appear to be loaded correctly; instead they still have their initialization values (gammas = 1, betas = 0). This gives me bad results at test time. Training and validation work as expected and the batch normalization parameters do get updated, so I think the problem must be in how the model is saved and/or restored.
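For reference, this is roughly how the parameters are checked after restoring (a sketch using the layer attributes of the model defined below; it assumes the model has already been built by calling it on data at least once):

# Sketch: inspect the BN parameters of the first conv unit after restoring.
bn = model.conv1.bn
print(bn.gamma.numpy())            # observed: all ones (the initializer value)
print(bn.beta.numpy())             # observed: all zeros (the initializer value)
print(bn.moving_mean.numpy())
print(bn.moving_variance.numpy())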

The model is defined with the subclassing API:

class convUnit(tf.keras.layers.Layer):
    def __init__(self, numFilters):
        super(convUnit, self).__init__()
        self.conv = tf.keras.layers.Conv2D(numFilters, kernel_size=3, padding='same')
        self.bn = tf.keras.layers.BatchNormalization(trainable=True)
        self.relu = tf.keras.layers.Activation('relu')

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.bn(x)
        x = self.relu(x)
        return x

class Network(tf.keras.Model):  #### USING SUBCLASSING API
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = convUnit(32)
        self.mp1 = tf.keras.layers.MaxPooling2D(pool_size=2, padding='same')

        self.conv2 = convUnit(32)
        self.mp2 = tf.keras.layers.MaxPooling2D(pool_size=2, padding='same')

        self.conv3 = convUnit(32)
        self.mp3 = tf.keras.layers.MaxPooling2D(pool_size=2, padding='same')

        self.conv4 = convUnit(32)
        self.mp4 = tf.keras.layers.MaxPooling2D(pool_size=2, padding='same')

        self.flat = tf.keras.layers.Flatten()
        self.drop1 = tf.keras.layers.Dropout(rate=0.5)
        self.dense1 = tf.keras.layers.Dense(units=128)

    def call(self, inputs, training):
        x = self.conv1(inputs)
        x = self.mp1(x)

        x = self.conv2(x)
        x = self.mp2(x)

        x = self.conv3(x)
        x = self.mp3(x)

        x = self.conv4(x)
        x = self.mp4(x)

        x = self.flat(x)
        x = self.drop1(x, training=training)  ###### training argument is passed to the dropout layer
        x = self.dense1(x)
        x = tf.math.l2_normalize(x, axis=1)
        return x
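Note that in eager execution the variables of a subclassed model are only created on the first call, so the model has to be built once before its weights can be inspected; a minimal sketch (the input shape here is an assumption, not the real one):

# Sketch: one dummy forward pass so that all layer variables get created.
# The shape (1, 64, 64, 3) is just an illustrative placeholder.
dummy = tf.zeros([1, 64, 64, 3])
model(dummy, training=False)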

To save the model I use:

model = Network()
modelCheckpoint = tf.train.Checkpoint(optimizer=optimizer, model=model, optimizer_step=tf.train.get_or_create_global_step())
lastModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=lastModelDir, max_to_keep=1)
bestModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=bestModelDir, max_to_keep=1)

...

if ep_mean_val_loss < best_val_loss:
    ######## BEST MODEL CHECKPOINT WRITING
    bestModelSaver.save()

And to restore the model:

model = Network()
modelCheckpoint = tf.train.Checkpoint(model=model)
bestModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=bestModelDir, max_to_keep=1)

modelCheckpoint.restore(bestModelSaver.latest_checkpoint)  ##### restore the best model's latest checkpoint
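For what it's worth, tf.train.Checkpoint.restore returns a status object that can report whether the restore matched the saved variables; a sketch of such a check:

# Sketch: capture and check the restore status (tf.train.Checkpoint API).
status = modelCheckpoint.restore(bestModelSaver.latest_checkpoint)
# With a subclassed model the actual restore is deferred until the variables
# are created, so the model must be called once before asserting.
status.assert_existing_objects_matched()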

I'm using TensorFlow 1.13 in eager execution mode. I'm still fairly new to the model subclassing API and to eager execution in TensorFlow, so maybe I'm missing something.

Any help would be greatly appreciated.

0 Answers:

No answers yet