如何在TensorFlow中恢复多个神经网络模型?

时间:2018-05-02 13:53:42

标签: python tensorflow neural-network save restore

我正在设计一个带有3个简单前馈NN的集合神经网络。现在我面临着恢复这3个神经网络以进行测试的问题。到目前为止,saver函数创建并保存了3个NN模型。

saver = tf.train.Saver()    
saver.save(sess, save_path=get_save_path(i), global_step=1000)

我已成功将它们保存为“.checkpoint”,“。meta”,“。index”和“.data文件,如下所示。

enter image description here

我尝试使用此编码恢复它们:

 saver = tf.train.import_meta_graph(get_save_path(i) + '-1000.meta')
 saver.restore(sess,tf.train.latest_checkpoint(save_dir))

但它只恢复了第三个NN network2进行测试。它影响了我的结果,因为算法只需要1个模型(network2)并假设所有三个NN模型在集合函数中是相同的。

供参考:

我理想的合奏功能:

ensemble = (network0 + network1 + network2) / 3

真实结果:

ensemble = (network2 + network2 + network2) / 3

如何让TF一起恢复所有3个NN模型?

1 个答案:

答案 0 :(得分:0)

我认为你搞混了。但我先回答一下这个问题:

您需要在不同范围内多次创建模型。那么应该可以平均那些变量。

假设您通过

创建3个网络
import tensorflow as tf

# save 3 version
for i in range(3):
    tf.reset_default_graph()

    a = tf.get_variable('test', [1])

    assign_op = a.assign([i])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(assign_op)
        print a.name, sess.run(a)

        saver = tf.train.Saver(tf.global_variables())
        saver.save(sess, './model/version_%i' % i)

此处每个网络都具有相同的图形结构,并且只包含单个参数/权重名称“test”。

然后你需要多次创建相同的图表,但是在不同的 variable_scopes下,比如

# load all versions in different scopes
tf.reset_default_graph()

a_collection = []

for i in range(3):
    # use different var-scopes
    with tf.variable_scope('scope_%0i' % i):
        # create same network
        a = tf.get_variable('test', [1])
        a_collection.append(a)

现在,每个恢复程序都需要知道应该使用哪个范围或变量名称映射。这可以通过

来完成
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print zip(sess.run(a_collection), [n.name for n in a_collection])

    for i in range(3):
        loader = tf.train.Saver({"test": a_collection[i]})
        loader = loader.restore(sess, './model/version_%i' % i)

    print sess.run(a_collection)

哪个会给你

 [array([0.], dtype=float32), array([1.], dtype=float32), array([2.], dtype=float32)]

正如所料。现在,你可以用你的模型做任何你想做的事。

但这是整体预测的工作方式!在整体模型中,您通常只对预测进行平均。因此,您可以使用不同的模型多次运行脚本,然后对预测进行平均。

如果你真的想平均模型的权重,可以考虑使用numpy将权重转换为python-dict。