I can't get the batch normalization layers of my trained model to restore from a checkpoint at test time.
When I inspect the batch norm parameters after restoring, they don't seem to be loaded correctly; instead they still have their initialization values (gammas = 1, betas = 0), which gives me bad results at test time. Training and validation work as expected and the batch norm parameters do get updated, so I think the problem must be in saving and/or restoring the model.
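For reference, this is roughly how I check the batch norm weights after restoring (a minimal sketch; the dummy input shape is just an example, used only to make sure all variables have been built before reading them):
import numpy as np
import tensorflow as tf

_ = model(tf.zeros([1, 64, 64, 3]), training=False)  # dummy forward pass so every variable exists
for var in model.variables:
    if 'gamma' in var.name or 'beta' in var.name:
        print(var.name, np.unique(var.numpy()))  # after restoring this still prints 1.0 / 0.0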
The model is defined using the subclassing API:
class convUnit(tf.keras.layers.Layer):
    def __init__(self, numFilters):
        super(convUnit, self).__init__()
        self.conv = tf.keras.layers.Conv2D(numFilters, kernel_size=3, padding='same')
        self.bn = tf.keras.layers.BatchNormalization(trainable=True)
        self.relu = tf.keras.layers.Activation('relu')

    def call(self, inputs):
        x = self.conv(inputs)
        x = self.bn(x)
        x = self.relu(x)
        return x
class Network(tf.keras.Model):  #### USING SUBCLASSING API
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = convUnit(32)
        self.mp1 = tf.keras.layers.MaxPooling2D(pool_size=2, padding='same')
        self.conv2 = convUnit(32)
        self.mp2 = tf.keras.layers.MaxPooling2D(pool_size=2, padding='same')
        self.conv3 = convUnit(32)
        self.mp3 = tf.keras.layers.MaxPooling2D(pool_size=2, padding='same')
        self.conv4 = convUnit(32)
        self.mp4 = tf.keras.layers.MaxPooling2D(pool_size=2, padding='same')
        self.flat = tf.keras.layers.Flatten()
        self.drop1 = tf.keras.layers.Dropout(rate=0.5)
        self.dense1 = tf.keras.layers.Dense(units=128)

    def call(self, inputs, training):
        x = self.conv1(inputs)
        x = self.mp1(x)
        x = self.conv2(x)
        x = self.mp2(x)
        x = self.conv3(x)
        x = self.mp3(x)
        x = self.conv4(x)
        x = self.mp4(x)
        x = self.flat(x)
        x = self.drop1(x, training=training)  ###### training argument passed to the dropout layer
        x = self.dense1(x)
        x = tf.math.l2_normalize(x, axis=1)
        return x
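For completeness, this is roughly how the model is called in eager mode (batch shape just for illustration): training=True inside the training loop and training=False for validation and testing:
model = Network()
images = tf.zeros([8, 64, 64, 3])                  # dummy batch, shape is just an example
train_embeddings = model(images, training=True)    # dropout active
test_embeddings = model(images, training=False)    # dropout disabled
print(test_embeddings.shape)                       # (8, 128), rows are L2-normalized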
To save the model I use:
model = Network()
modelCheckpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,optimizer_step=tf.train.get_or_create_global_step())
lastModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=lastModelDir, max_to_keep=1)
bestModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=bestModelDir, max_to_keep=1)
...
if ep_mean_val_loss < best_val_loss:
    ######## BEST MODEL CHECKPOINT WRITING
    bestModelSaver.save()
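In case it is relevant, the contents of the saved checkpoint can be listed like this (a sketch, filtering for the batch norm variables, just to confirm they get written to disk at all):
for name, shape in tf.train.list_variables(bestModelSaver.latest_checkpoint):
    if 'gamma' in name or 'beta' in name or 'moving' in name:
        print(name, shape)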
And to restore the model:
model = Network()
modelCheckpoint = tf.train.Checkpoint(model=model)
bestModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=bestModelDir, max_to_keep=1)
modelCheckpoint.restore(bestModelSaver.latest_checkpoint) ##### restore the best model's latest checkpoint
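As far as I understand, with tf.train.Checkpoint the restore is deferred until the matching variables have been created, so a check along these lines (dummy input shape just for illustration) should reveal whether the batch norm variables are matched at all:
_ = model(tf.zeros([1, 64, 64, 3]), training=False)  # build all variables before restoring
status = modelCheckpoint.restore(bestModelSaver.latest_checkpoint)
status.assert_existing_objects_matched()  # should raise if e.g. the BN weights were not matched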
I am using TensorFlow 1.13 in eager execution mode. I am fairly new to the model subclassing API and to eager execution in TensorFlow, so maybe I am missing something.
Any help would be greatly appreciated.