我正在研究一个具有2000个神经元和8个以上恒定输入神经元的1层隐藏神经网络,以解决回归问题。
特别是,作为优化程序,我正在使用学习参数= 0.001的RMSprop,从输入到隐藏层的ReLU激活,从隐藏到输出的线性化。我还使用了一个小批量梯度下降(32个观测值)并运行了2000次模型,即epochs =2000。
我的目标是,经过培训,从2000年的最佳神经网络中提取权重(经过多次试验,最好的永远不会是最后的试验,而最好的意思是导致最小的MSE)。
使用save_weights('my_model_2.h5',save_format ='h5')确实有效,但是据我了解,它从最后一个时期中提取权重,而我希望从NN表现最好的那个时期中提取权重。请找到我编写的代码:
def build_first_NN():
model = keras.Sequential([
layers.Dense(2000, activation=tf.nn.relu, input_shape=[len(X_34.keys())]),
layers.Dense(1)
])
optimizer = tf.keras.optimizers.RMSprop(0.001)
model.compile(loss='mean_squared_error',
optimizer=optimizer,
metrics=['mean_absolute_error', 'mean_squared_error']
)
return model
first_NN = build_first_NN()
history_firstNN_all_nocv = first_NN.fit(X_34,
y_34,
epochs = 2000)
first_NN.save_weights('my_model_2.h5', save_format='h5')
trained_weights_path = 'C:/Users/Myname/Desktop/otherfolder/Data/my_model_2.h5'
trained_weights = h5py.File(trained_weights_path, 'r')
weights_0 = pd.DataFrame(trained_weights['dense/dense/kernel:0'][:])
weights_1 = pd.DataFrame(trained_weights['dense_1/dense_1/kernel:0'][:])
然后提取的权重应该是2000个时期中最后一个的权重:我如何才能从MSE最小的那个获得权重?
期待任何评论。
编辑:已解决
基于所收到的建议(对于一般的兴趣),这就是我更新代码并满足我的范围的方式:
# build_first_NN() as defined before
first_NN = build_first_NN()
trained_weights_path = 'C:/Users/Myname/Desktop/otherfolder/Data/my_model_2.h5'
checkpoint = ModelCheckpoint(trained_weights_path,
monitor='mean_squared_error',
verbose=1,
save_best_only=True,
mode='min')
history_firstNN_all_nocv = first_NN.fit(X_34,
y_34,
epochs = 2000,
callbacks = [checkpoint])
trained_weights = h5py.File(trained_weights_path, 'r')
weights_0 = pd.DataFrame(trained_weights['model_weights/dense/dense/kernel:0'][:])
weights_1 = pd.DataFrame(trained_weights['model_weights/dense_1/dense_1/kernel:0'][:])
答案 0 :(得分:1)
使用Keras的ModelCheckpoint
回调。
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(filepath, monitor='val_mean_squared_error', verbose=1, save_best_only=True, mode='max')
将此用作您model.fit()
中的回调。这将始终以最高的验证准确性(验证时的最低MSE)将模型保存在filepath
指定的位置。
您可以找到文档here。 当然,您在培训期间需要为此提供验证数据。否则,我认为您可以自己编写回调函数来查看最低培训的MSE。