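I am working on a course assignment in which I have to save and load a model in Keras. My code for creating the model, training it, and saving it is: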
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, MaxPooling2D
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

def get_new_model(input_shape):
    """
    This function should build a Sequential model according to the above specification. Ensure the
    weights are initialised by providing the input_shape argument in the first layer, given by the
    function argument.
    Your function should also compile the model with the Adam optimiser, sparse categorical cross
    entropy loss function, and a single accuracy metric.
    """
    model = Sequential([
        Conv2D(16, kernel_size=(3, 3), activation='relu', padding='same', name='conv_1', input_shape=input_shape),
        Conv2D(8, kernel_size=(3, 3), activation='relu', padding='same', name='conv_2'),
        MaxPooling2D(pool_size=(8, 8), name='pool_1'),
        tf.keras.layers.Flatten(name='flatten'),
        Dense(32, activation='relu', name='dense_1'),
        Dense(10, activation='softmax', name='dense_2')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
    return model

model = get_new_model(x_train[0].shape)

def get_checkpoint_every_epoch():
    """
    This function should return a ModelCheckpoint object that:
    - saves the weights only at the end of every epoch
    - saves into a directory called 'checkpoints_every_epoch' inside the current working directory
    - generates filenames in that directory like 'checkpoint_XXX' where
      XXX is the epoch number formatted to have three digits, e.g. 001, 002, 003, etc.
    """
    path = 'checkpoints_every_epoch/checkpoint_{epoch:02d}'
    checkpoint = ModelCheckpoint(filepath=path, save_weights_only=True, save_freq='epoch')
    return checkpoint

def get_checkpoint_best_only():
    """
    This function should return a ModelCheckpoint object that:
    - saves only the weights that generate the highest validation (testing) accuracy
    - saves into a directory called 'checkpoints_best_only' inside the current working directory
    - generates a file called 'checkpoints_best_only/checkpoint'
    """
    path = 'checkpoints_best_only/checkpoint'
    checkpoint = ModelCheckpoint(filepath=path, save_best_only=True, save_weights_only=True, monitor='val_acc')
    return checkpoint

def get_early_stopping():
    """
    This function should return an EarlyStopping callback that stops training when
    the validation (testing) accuracy has not improved in the last 3 epochs.
    HINT: use the EarlyStopping callback with the correct 'monitor' and 'patience'
    """
    return EarlyStopping(monitor='val_acc', patience=3)

checkpoint_every_epoch = get_checkpoint_every_epoch()
checkpoint_best_only = get_checkpoint_best_only()
early_stopping = get_early_stopping()
callbacks = [checkpoint_every_epoch, checkpoint_best_only, early_stopping]
model.fit(x_train, y_train, epochs=50, validation_data=(x_test, y_test), callbacks=callbacks)
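
A side note on the filename template in get_checkpoint_every_epoch: the docstring asks for three-digit epoch numbers, while the path uses {epoch:02d}, which pads to two digits. The sketch below only illustrates the difference with plain Python string formatting. It assumes ModelCheckpoint fills the placeholder with ordinary str.format rules and a 1-based epoch number; the helper name checkpoint_name is hypothetical and is not part of Keras.

```python
def checkpoint_name(template, epoch):
    """Hypothetical helper: fill an epoch placeholder the way str.format does."""
    return template.format(epoch=epoch)


two_digits = 'checkpoints_every_epoch/checkpoint_{epoch:02d}'
three_digits = 'checkpoints_every_epoch/checkpoint_{epoch:03d}'

for epoch in (1, 23):
    # Show both paddings side by side for the same epoch number.
    print(checkpoint_name(two_digits, epoch), checkpoint_name(three_digits, epoch))
# checkpoints_every_epoch/checkpoint_01 checkpoints_every_epoch/checkpoint_001
# checkpoints_every_epoch/checkpoint_23 checkpoints_every_epoch/checkpoint_023
```

If the three-digit form is what the assignment checks for, {epoch:03d} would be the template to use. This is unrelated to the loading error described in the rest of the post.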
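With these callbacks, the weights from every epoch are saved to checkpoints_every_epoch/checkpoint_{epoch:02d} and the best weights are saved to checkpoints_best_only/checkpoint. Now, when I try to load both of them with this code: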
def get_model_last_epoch(model):
    """
    This function should create a new instance of the CNN you created earlier,
    load on the weights from the last training epoch, and return this model.
    """
    filepath = tf.train.latest_checkpoint('checkpoint_every_epoch')
    model.load_weights(filepath)
    return model

def get_model_best_epoch(model):
    """
    This function should create a new instance of the CNN you created earlier, load
    on the weights leading to the highest validation accuracy, and return this model.
    """
    filepath = tf.train.latest_checkpoint('checkpoint_best_only')
    model.load_weights(filepath)
    return model

model_last_epoch = get_model_last_epoch(get_new_model(x_train[0].shape))
model_best_epoch = get_model_best_epoch(get_new_model(x_train[0].shape))
print('Model with last epoch weights:')
get_test_accuracy(model_last_epoch, x_test, y_test)
print('')
print('Model with best epoch weights:')
get_test_accuracy(model_best_epoch, x_test, y_test)
The error I get is:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-18-b6d169507ca4> in <module>
3 # Verify that the second has a higher validation (testing) accuarcy.
4
----> 5 model_last_epoch = get_model_last_epoch(get_new_model(x_train[0].shape))
6 model_best_epoch = get_model_best_epoch(get_new_model(x_train[0].shape))
7 print('Model with last epoch weights:')
<ipython-input-15-6f7ff0c732b4> in get_model_last_epoch(model)
10 """
11 filepath = tf.train.latest_checkpoint('checkpoint_every_epoch')
---> 12 model.load_weights(filepath)
13 return model
14
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in load_weights(self, filepath, by_name)
179 raise ValueError('Load weights is not yet supported with TPUStrategy '
180 'with steps_per_run greater than 1.')
--> 181 return super(Model, self).load_weights(filepath, by_name)
182
183 @trackable.no_automatic_dependency_tracking
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in load_weights(self, filepath, by_name)
1137 format.
1138 """
-> 1139 if _is_hdf5_filepath(filepath):
1140 save_format = 'h5'
1141 else:
/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in _is_hdf5_filepath(filepath)
1447
1448 def _is_hdf5_filepath(filepath):
-> 1449 return (filepath.endswith('.h5') or filepath.endswith('.keras') or
1450 filepath.endswith('.hdf5'))
1451
AttributeError: 'NoneType' object has no attribute 'endswith'
Can someone tell me what is wrong with my code, or how I can change it to get rid of the error?
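
A note on reading the traceback: the last frame fails on filepath.endswith(...) inside _is_hdf5_filepath, and the message says the object is a NoneType. In other words, the value handed to model.load_weights was None. tf.train.latest_checkpoint is documented to return None when it finds no checkpoint in the directory it is given. The snippet below reproduces the same message without TensorFlow; is_hdf5_filepath is a simplified, hypothetical stand-in for the check shown in the traceback, not the library's actual function.

```python
def is_hdf5_filepath(filepath):
    # Simplified stand-in for the extension check in the last traceback frame.
    return filepath.endswith(('.h5', '.keras', '.hdf5'))


try:
    # Passing None mimics what happens when no checkpoint path was found.
    is_hdf5_filepath(None)
except AttributeError as exc:
    message = str(exc)
    print(message)
```

So the thing to check is why the checkpoint lookup came back empty, rather than load_weights itself.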
Edit:
If I do this on a single model without using tf.train.latest_checkpoint to get the last filename, it works. That is:
dummyModel.load_weights('checkpoints_every_epoch/checkpoint_23')
print('Model with last epoch weights:')
get_test_accuracy(dummyModel, x_test, y_test)
print('')
Answer 0 (score: 1)
I figured it out. The directory name in the path was wrong, and it took me a long time to spot it. The correct functions are:
def get_model_last_epoch(model):
    """
    This function should create a new instance of the CNN you created earlier,
    load on the weights from the last training epoch, and return this model.
    """
    model.load_weights(tf.train.latest_checkpoint('checkpoints_every_epoch'))
    return model

def get_model_best_epoch(model):
    """
    This function should create a new instance of the CNN you created earlier, load
    on the weights leading to the highest validation accuracy, and return this model.
    """
    #filepath = tf.train.latest_checkpoint('checkpoints_best_only')
    model.load_weights(tf.train.latest_checkpoint('checkpoints_best_only'))
    return model
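With these there is no error, because the directory names passed to tf.train.latest_checkpoint now match the ones the callbacks wrote to: 'checkpoints_every_epoch' and 'checkpoints_best_only', not 'checkpoint_every_epoch' and 'checkpoint_best_only'. With the misspelled names, tf.train.latest_checkpoint found no checkpoint and returned None, and passing None to load_weights is what raised the AttributeError.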
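
To make this kind of typo fail with a clearer message, one option is to check the result of tf.train.latest_checkpoint before loading. The following is only a sketch: load_latest_weights is a hypothetical helper (not part of TensorFlow or the course code), and it assumes the weights were saved in the TensorFlow checkpoint format, as in the question.

```python
import os


def load_latest_weights(model, checkpoint_dir):
    """Load the newest checkpoint found in checkpoint_dir into model.

    Hypothetical helper, not part of TensorFlow or the course code.
    """
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(
            "Checkpoint directory %r does not exist; compare its spelling "
            "with the filepath given to ModelCheckpoint." % checkpoint_dir
        )

    # Imported here so the directory check above can run on its own.
    import tensorflow as tf

    # latest_checkpoint returns None when it finds no checkpoint.
    filepath = tf.train.latest_checkpoint(checkpoint_dir)
    if filepath is None:
        raise FileNotFoundError("No checkpoint found in %r." % checkpoint_dir)

    model.load_weights(filepath)
    return model
```

With this helper, get_model_last_epoch could return load_latest_weights(model, 'checkpoints_every_epoch'), and the misspelled 'checkpoint_every_epoch' would stop with a FileNotFoundError naming the directory instead of the AttributeError above.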