无法加载张量流(tf-agent)保存的模型

时间:2019-06-11 00:35:32

标签: python tensorflow

我正在使用以下代码创建tf代理DqnAgent:

tf_agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
    train_step_counter=train_step_counter

在训练循环中,我将使用此模型保存

tf.saved_model.save(tf_agent, saved_models_path)

经过培训后,我想用以下方式加载保存的模型

if tf.saved_model.contains_saved_model(saved_models_path):
    tf_agent = tf.saved_model.load(saved_models_path)

仅当saved_path中的文件夹包含一个,函数contains_saved_model(saved_models_path)返回True时,此代码才会加载已保存的模型,因此已加载模型,但有例外,并且程序崩溃:

Traceback (most recent call last):
    File "/home/claudino/Projetos/dino-tf-agents/dino_ia/model/agent.py", line 50, in <module>
        tf_agent = tf.saved_model.load(saved_models_path)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 408, in load
        return load_internal(export_dir, tags)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 432, in load_internal
        export_dir)
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 58, in __init__
        self._load_all()
    File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 168, in _load_all
        slot_variable = optimizer_object.add_slot(
    AttributeError: '_UserObject' object has no attribute 'add_slot'

    Process finished with exit code 1

我浏览了tensorflow代码,但找不到问题。有人可以帮助我吗?

我之所以使用tf-agents-nightly,是因为Google的源代码无法在tf-agents“稳定”版本上运行(我不确定tf-agents是否真的稳定),并尝试使用{{ 1}} 1.3和tensorflow,会出现相同的问题。

0 个答案:

没有答案