使用tf.train.Saver加载模型检查点时如何修改张量形状?

时间:2017-04-10 07:39:00

标签: python tensorflow

我培训了一个固定批量大小的RNN,但现在我想修改我用tf.train.Saver保存的图表,使批量大小为1进行推理。我怎么能这样做?

session = tf.InteractiveSession()
saver = tf.train.import_meta_graph('model.ckpt.meta')
saver.restore(session, 'model.ckpt')

1 个答案:

答案 0 :(得分:0)

实现此目的的一种方法是在测试时重建不同的(尽管是兼容的)网络,并将恢复仅限制为权重。

在培训期间,

net = make_my_net(batch_size)
...
saver.save(session, model_name)

在测试期间,

net = make_my_net(1)
...
saver.restore(session, model_name)

后者会将变量(包括网络权重)的值替换为先前保存的值。您不必根据documentation初始化您要覆盖的变量,尽管我认为并非总是如此。

请注意,重建不同的网络可让您有机会构建更清洁的测试网络,例如:删除诸如dropout之类的图层。