如何使用TensorFlow在Returnn中加载经过训练的网络的权重

时间:2018-06-26 13:10:35

标签: python tensorflow returnn

当尝试在训练有素的多个时间段上加载保存的权重时 使用以下代码返回网络:

import tensorflow as tf
from returnn.Config import Config
from returnn.TFNetwork import TFNetwork

for i in range(1,11):
    modelFilePath = path/to/model/ + 'network.' + '%03d' % (i,)

    returnnConfig = Config()
    returnnConfig.load_file(path/to/configFile)
    returnnTfNetwork = TFNetwork(config=path/to/configFile, train_flag=False, eval_flag=True)

    returnnTfNetwork.construct_from_dict(returnnConfig.typed_value('network'))

    with tf.Session() as sess:
        returnnTfNetwork.load_params_from_file(modelFilePath, sess)

我收到以下错误:

Variables to restore which are not in checkpoint:
global_step_1

Variables in checkpoint which are not needed for restore:
global_step

Probably we can restore these:
(None)

Error, some entry is missing in the checkpoint

1 个答案:

答案 0 :(得分:1)

问题是您每次在循环中都重新创建TFNetwork,并且在那里每次全局步骤也都会创建一个新变量,由于每个变量必须具有唯一的名称,因此必须将其称为不同。

您可以在循环中执行以下操作:

tf.reset_default_graph()