我目前正在使用GPR使用Gpflow实现算法。我想在GPR训练后保存参数并加载模型以进行测试。有人知道命令吗?
答案 0 :(得分:3)
GPflow现在有一个包含tips & tricks的页面。您可以点击链接以找到问题的答案。但是,我还要在此处粘贴MWE:
假设您要存储GPR模型,可以使用gpflow.Saver()
来完成:
kernel = gpflow.kernels.RBF(1)
x = np.random.randn(100, 1)
y = np.random.randn(100, 1)
model = gpflow.models.GPR(x, y, kernel)
filename = "/tmp/gpr.gpflow"
path = Path(filename)
if path.exists():
path.unlink()
saver = gpflow.saver.Saver()
saver.save(filename, model)
要将其加载回去,您必须使用以下解决方案之一:
with tf.Graph().as_default() as graph, tf.Session().as_default():
model_copy = saver.load(filename)
或者如果您想在之前存储模型的同一会话中加载模型,则需要应用一些技巧:
ctx_for_loading = gpflow.saver.SaverContext(autocompile=False)
model_copy = saver.load(filename, context=ctx_for_loading)
model_copy.clear()
model_copy.compile()
答案 1 :(得分:0)
我为gpflow模型采用的一个选项是仅保存和加载可训练对象。假定您具有构建和编译模型的功能。 通过将变量保存到hdf5文件中,我在下面展示了这一点。
import h5py
def _load_model(model, load_file):
"""
Load a model given by model path
"""
vars = {}
def _gather(name, obj):
if isinstance(obj, h5py.Dataset):
vars[name] = obj[...]
with h5py.File(load_file) as f:
f.visititems(_gather)
model.assign(vars)
def _save_model(model, save_file):
vars = model.read_trainables()
with h5py.File(save_file) as f:
for name, value in vars.items():
f[name] = value