rllib-从检查点获取TensorFlow或PyTorch模型输出

时间:2020-08-23 14:28:53

标签: tensorflow pytorch ray rllib

我想在其他代码中使用rllib训练的策略模型,在该代码中,我需要跟踪针对特定输入状态生成的操作。使用标准的TensorFlow或PyTorch(首选)网络模型可以提供这种灵活性,但是我找不到关于如何从经过训练的rllib代理中生成可用的dat或H5文件的清晰文档,然后将其加载到Torch或tf / Keras中模型。

1 个答案:

答案 0 :(得分:0)

从检查点获取权重的最简单方法是使用 rllib 再次加载它,然后使用 Tensorflow/Pytorch 命令保存它。 如果你有一个 keras TF 模型,你可以简单地调用:

model.save('my_model.h5') # creates a HDF5 file