TensorFlow MultiRNNCell save and restore

Date: 2018-07-10 07:37:23

Tags: tensorflow deep-learning load rnn

Saving or restoring a model that uses MultiRNNCell does not seem to work properly.

I am using the following code for a classification problem:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

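I then use rnn_cell (stacked) to train the model and save it.

In my case, I trained two models: one with num_layers = 2 and another with num_layers = 3.

Afterwards, I first run the code above and then run the restore step to replace the values of those variables with the saved weights.

It seems that only the first layer of rnn_cell is loaded, because the results I get after loading with num_layers = 1 are exactly the same as those of the model with num_layers = 2 or num_layers = 3.

The model itself loads fine, so I can only conclude that the weights are not being saved or loaded correctly.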
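
One way to narrow this down is to compare the variable names the graph expects with the names stored in the checkpoint. The sketch below is plain Python with no TensorFlow dependency. The helpers compare_names and lstm_names are hypothetical, and the name lists are shortened stand-ins modeled on the variable names in this question. In a real TensorFlow 1.x session the two lists could come from tf.global_variables() and tf.train.list_variables(checkpoint_path). It illustrates an assumption, not a confirmed diagnosis: a Saver restores only the variables present in the current graph, so a graph built with fewer layers can load from a deeper checkpoint without raising an error.

```python
# Hypothetical diagnostic: compare graph variable names with checkpoint names.
# The names are illustrative stand-ins, not read from a real checkpoint.

PREFIX = "dynamic_rnn/fw/multi_rnn_cell"


def lstm_names(num_layers):
    """Build the kernel/bias names a stacked LSTM with num_layers would create."""
    names = []
    for i in range(num_layers):
        names.append(f"{PREFIX}/cell_{i}/lstm_cell/kernel")
        names.append(f"{PREFIX}/cell_{i}/lstm_cell/bias")
    return names


def compare_names(graph_names, checkpoint_names):
    """Group names by whether they appear in the graph, the checkpoint, or both."""
    graph, ckpt = set(graph_names), set(checkpoint_names)
    return {
        "restored": sorted(graph & ckpt),
        "unused_in_checkpoint": sorted(ckpt - graph),
        "missing_from_checkpoint": sorted(graph - ckpt),
    }


# Checkpoint trained with three layers, graph rebuilt with only one layer.
checkpoint_names = lstm_names(3)
graph_names = lstm_names(1)

report = compare_names(graph_names, checkpoint_names)
for group, names in report.items():
    print(group, len(names))
    for name in names:
        print("   ", name)
```

If "unused_in_checkpoint" is non-empty, the graph being restored is shallower than the one that was saved, which would produce identical results regardless of how many layers the checkpoint holds.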

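===== EDIT: I loaded the checkpoint without building any matching model and listed the global variables with tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), which gives: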

[<tf.Variable 'global_step:0' shape=() dtype=int32_ref>,
 <tf.Variable 'embedding_layer/w:0' shape=(11441, 200) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w1:0' shape=(400, 400) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w2:0' shape=(400, 400) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w3:0' shape=(400, 1) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/b2:0' shape=(400,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/concat_w:0' shape=(800, 400) dtype=float32_ref>,
 <tf.Variable 'logits/w:0' shape=(800, 7) dtype=float32_ref>,
 <tf.Variable 'logits/b:0' shape=(7,) dtype=float32_ref>,
 <tf.Variable 'train_optimizer/beta1_power:0' shape=() dtype=float32_ref>,
 <tf.Variable 'train_optimizer/beta2_power:0' shape=() dtype=float32_ref>,
 <tf.Variable 'embedding_layer/w/Adam:0' shape=(11441, 200) dtype=float32_ref>,
 <tf.Variable 'embedding_layer/w/Adam_1:0' shape=(11441, 200) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w1/Adam:0' shape=(400, 400) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w1/Adam_1:0' shape=(400, 400) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w2/Adam:0' shape=(400, 400) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w2/Adam_1:0' shape=(400, 400) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w3/Adam:0' shape=(400, 1) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/w3/Adam_1:0' shape=(400, 1) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/b2/Adam:0' shape=(400,) dtype=float32_ref>,
 <tf.Variable 'dynamic_rnn/attn/b2/Adam_1:0' shape=(400,) dtype=float32_ref>,
 <tf.Variable 'logits/w/Adam:0' shape=(800, 7) dtype=float32_ref>,
 <tf.Variable 'logits/w/Adam_1:0' shape=(800, 7) dtype=float32_ref>,
 <tf.Variable 'logits/b/Adam:0' shape=(7,) dtype=float32_ref>,
 <tf.Variable 'logits/b/Adam_1:0' shape=(7,) dtype=float32_ref>]

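This listing shows that all three hidden layers of the RNN cell were stored as expected. However, the weights do not seem to be matched automatically to a model that has multiple hidden layers.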
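
The layer count can also be read off such a listing programmatically instead of by eye. The following is a minimal sketch in plain Python. The LISTING string holds a few lines copied from the output above, and the regular expression and variable names are illustrative choices. The pattern requires ":0" directly after kernel or bias, so the Adam slot variables are not counted as layers.

```python
import re
from collections import defaultdict

# A few entries copied from the variable listing in this question.
LISTING = """
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'logits/w:0' shape=(800, 7) dtype=float32_ref>,
"""

# Match only model weights: ":0" must follow kernel/bias directly,
# which excludes optimizer slots such as ".../kernel/Adam:0".
PATTERN = re.compile(
    r"'dynamic_rnn/(fw|bw)/multi_rnn_cell/cell_(\d+)/lstm_cell/(?:kernel|bias):0'"
)

layers = defaultdict(set)
for direction, index in PATTERN.findall(LISTING):
    layers[direction].add(int(index))

# Number of distinct cell indices per direction.
counts = {direction: len(indices) for direction, indices in layers.items()}
print(counts)
```

Running the same count on the variables of the graph that is being restored into would show whether it really has as many cells as the checkpoint.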

=====

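I tried searching for how to save and restore models with a deep RNN but could not find anything, so I am asking for help here.

Has anyone run into the same problem and found a solution?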

1 answer:

Answer 0 (score: 0)

You can change "tf.contrib.rnn.MultiRNNCell" to "tf.nn.rnn_cell.MultiRNNCell".