如何使用Saver对象在TensorFlow中保存变量字典?

时间:2016-09-28 07:33:03

标签: tensorflow

我在TensorFlow中构建了一个简单的4层神经网络(2个隐藏层)。我没有使用TensorFlow提供的内置NN,而是实现了我自己的基本版本。现在,为了保持W(权重)和B(偏见)Tensor在一个地方,我建立了这些变量的字典,如下所示:

weights = {
    'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
}
biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}

在我学习这些参数后,我想使用Saver对象保存它们。我试过这个:

saver = tf.train.Saver([weights,biases])
save_path = saver.save(sess,"./data/model.ckpt")

但我没有成功。看到的错误是:

TypeError: unhashable type: 'dict'

现在,一个解决方案是将所有变量:h1,h2,out,b1,b2,bias_out(偏离字典)分成单个变量并保存它们但这似乎是一种天真的方法。如果我后来有更多的变量需要联合起来,我想保持这种方式,它更干净,更易于管理。如何将分组变量保存在一起?

1 个答案:

答案 0 :(得分:2)

Tensorflow Saver不接受字典列表。也许你应该尝试merge your dictionaries first

parameters = weights.copy()
parameters.update(bias)

或(使用Python 3.5)

parameters = {**weights,**bias}

之后:

saver = tf.train.Saver(parameters)
save_path = saver.save(sess,"./data/model.ckpt")

另一种解决方案:

saver = tf.train.Saver({name:variable for name,variable in weights.items()+bias.items()})
save_path = saver.save(sess,"./data/model.ckpt")

最后一个解决方案可能会遇到一些问题,例如" out"作为权重和偏见的关键,所以似乎只会保存一个[" out":变量]。