如何保存JAX训练模型的优化器状态?

时间:2020-10-27 08:43:07

标签: python machine-learning jax

我正在玩mnist_vae示例,但无法弄清楚如何正确保存/加载训练模型的权重。

enc_init_rng, dec_init_rng = random.split(random.PRNGKey(2))
_, init_encoder_params = encoder_init(enc_init_rng, (batch_size, 28 * 28))
_, init_decoder_params = decoder_init(dec_init_rng, (batch_size, 10))
init_params = init_encoder_params, init_decoder_params

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=0.9)
opt_state = opt_init(init_params)

之后,我使用opt_update训练模型并希望保存它。但是,我没有找到任何将优化器状态保存到磁盘的功能。

我尝试保存参数并使用它们初始化opt_state,但并非所有信息都保存下来,结果opt_state_1不是原始的opt_state。

weights=get_params(opt_state)  
jnp.save(file, weights)  
weights = jnp.load(file,allow_pickle=True)  
opt_state_1 = opt_init(init_params)

如何正确保存我训练的模型?

1 个答案:

答案 0 :(得分:0)

import pickle
from jax.experimental import optimizers

trained_params = optimizers.unpack_optimizer_state(opt_state)
pickle.dump(trained_params, open(os.path.join(config["ckpt_path"], "best_ckpt.pkl"), "wb"))

best_params = pickle.load(open(os.path.join(config["ckpt_path"], "best_ckpt.pkl"), "rb"))
best_opt_state = optimizers.pack_optimizer_state(best_params)