Saving or restoring a model with MultiRNNCell does not seem to work properly.
I am using the following code for a classification problem:
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
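I then use a stacked rnn_cell to train and save the model.
In my case I trained two models: one with num_layers = 2 and one with num_layers = 3.
After that, I first run the code above and then run the restore step so that the variables above are overwritten with the saved weight values.
It seems that only the first layer of the rnn_cell is loaded, because the result I get when loading with num_layers = 1 is exactly the same as the result for the models with num_layers = 2 or num_layers = 3.
The model itself loads without errors, so I can only conclude that it is not being saved or loaded correctly.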
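One way to check how many stacked layers a checkpoint or graph really contains is to count the distinct `cell_N` scopes in its variable names. The snippet below is a minimal sketch in plain Python, not part of the original code: `count_stacked_layers` and the `names` list are hypothetical, and it assumes the names follow the `.../multi_rnn_cell/cell_N/...` pattern. In a real session the names could come from, for example, `tf.train.list_variables(checkpoint_path)`, if your TensorFlow version provides it.

```python
import re
from collections import defaultdict

# Matches names such as 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel'.
CELL_PATTERN = re.compile(r"(?P<prefix>.*multi_rnn_cell)/cell_(?P<index>\d+)/")


def count_stacked_layers(variable_names):
    """Return {multi_rnn_cell scope: number of distinct cell_N indices}."""
    indices = defaultdict(set)
    for name in variable_names:
        match = CELL_PATTERN.match(name)
        if match:
            indices[match.group("prefix")].add(int(match.group("index")))
    return {prefix: len(found) for prefix, found in indices.items()}


# Hypothetical sample: three forward layers and three backward layers.
names = [
    "dynamic_rnn/%s/multi_rnn_cell/cell_%d/lstm_cell/%s" % (direction, i, part)
    for direction in ("fw", "bw")
    for i in range(3)
    for part in ("kernel", "bias")
]
names.append("logits/w")  # variables outside the stacked cell are ignored

print(count_stacked_layers(names))
```

If the count for the checkpoint is larger than the count for the graph being restored into, the extra layers are present in the file but have nothing to be loaded into.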
===== Edit: I loaded the checkpoint without any matching model and inspected it using the code below
[<tf.Variable 'global_step:0' shape=() dtype=int32_ref>,
<tf.Variable 'embedding_layer/w:0' shape=(11441, 200) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w1:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w2:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w3:0' shape=(400, 1) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/b2:0' shape=(400,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/concat_w:0' shape=(800, 400) dtype=float32_ref>,
<tf.Variable 'logits/w:0' shape=(800, 7) dtype=float32_ref>,
<tf.Variable 'logits/b:0' shape=(7,) dtype=float32_ref>,
<tf.Variable 'train_optimizer/beta1_power:0' shape=() dtype=float32_ref>,
<tf.Variable 'train_optimizer/beta2_power:0' shape=() dtype=float32_ref>,
<tf.Variable 'embedding_layer/w/Adam:0' shape=(11441, 200) dtype=float32_ref>,
<tf.Variable 'embedding_layer/w/Adam_1:0' shape=(11441, 200) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/fw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/kernel/Adam_1:0' shape=(400, 800) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/bw/multi_rnn_cell/cell_2/lstm_cell/bias/Adam_1:0' shape=(800,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w1/Adam:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w1/Adam_1:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w2/Adam:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w2/Adam_1:0' shape=(400, 400) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w3/Adam:0' shape=(400, 1) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/w3/Adam_1:0' shape=(400, 1) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/b2/Adam:0' shape=(400,) dtype=float32_ref>,
<tf.Variable 'dynamic_rnn/attn/b2/Adam_1:0' shape=(400,) dtype=float32_ref>,
<tf.Variable 'logits/w/Adam:0' shape=(800, 7) dtype=float32_ref>,
<tf.Variable 'logits/w/Adam_1:0' shape=(800, 7) dtype=float32_ref>,
<tf.Variable 'logits/b/Adam:0' shape=(7,) dtype=float32_ref>,
<tf.Variable 'logits/b/Adam_1:0' shape=(7,) dtype=float32_ref>]
The output is then
{{1}}
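This shows that the expected variables for all three hidden layers of the RNN cell were stored correctly. However, the weights do not seem to be matched automatically to a model with this many hidden layers.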
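To see which saved variables would actually be matched, it can help to diff the checkpoint's variable names against the names in the graph being restored into. The following is a minimal sketch in plain Python under stated assumptions: `diff_variables` and `layer_vars` are hypothetical helpers, and the (name, shape) pairs are made up to mirror the listing above. As far as I know, a name-based `tf.train.Saver` restore only looks up the variables defined in the current graph, so extra checkpoint entries are skipped without an error; that would be consistent with a one-layer graph loading "successfully" from a three-layer checkpoint.

```python
def diff_variables(checkpoint_vars, graph_vars):
    """Compare two lists of (name, shape) pairs by variable name."""
    ckpt = {name: tuple(shape) for name, shape in checkpoint_vars}
    graph = {name: tuple(shape) for name, shape in graph_vars}
    shared = set(ckpt) & set(graph)
    return {
        # In the graph but absent from the checkpoint: restore would fail.
        "missing_in_checkpoint": sorted(set(graph) - set(ckpt)),
        # In the checkpoint but absent from the graph: never loaded.
        "unused_in_checkpoint": sorted(set(ckpt) - set(graph)),
        # Same name, different shape: restore would fail.
        "shape_mismatch": sorted(n for n in shared if ckpt[n] != graph[n]),
    }


def layer_vars(direction, index):
    """Hypothetical (name, shape) pairs for one LSTM layer, as in the listing."""
    scope = "dynamic_rnn/%s/multi_rnn_cell/cell_%d/lstm_cell" % (direction, index)
    return [(scope + "/kernel", (400, 800)), (scope + "/bias", (800,))]


# Checkpoint saved with three forward layers, graph rebuilt with only one.
checkpoint_vars = [v for i in range(3) for v in layer_vars("fw", i)]
graph_vars = layer_vars("fw", 0)

report = diff_variables(checkpoint_vars, graph_vars)
for key, names in report.items():
    print(key, names)
```

In this made-up case nothing is missing and no shapes clash, but the `cell_1` and `cell_2` variables show up as unused, so the restored graph only ever uses the first layer.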
=====
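I tried searching for how to save and restore a model with a deep RNN but could not find anything, so I am asking for help here.
Has anyone run into the same problem and found a solution?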
Answer 0 (score: 0)
You can change "tf.contrib.rnn.MultiRNNCell" to "tf.nn.rnn_cell.MultiRNNCell".