如何将权重转移到tensorflow RNN小区

时间:2016-12-23 05:29:58

标签: tensorflow

我在matlab中实现了一组训练模型的权重。我想将权重移植到tensorflow。然而,具有500个单元的tf.rnn.rnn_cell.LSTMCell具有形状的重量矩阵(1524,2000)。为什么1524?为何2000?这根本不符合我的权重。

我的模型有3个隐藏层,每个隐藏层有1000个节点,最后一个隐藏层是时间层(Recurrent)。输入维度为1539.输出维度为1026。 时间层具有1x1000时间加权和1000x1000层加权以及1x1000偏差。

import scipy
import numpy as np
import tensorflow as tf

x = tf.placeholder(shape=[None, 3, 1024], dtype=tf.float32, name='input')
cell = tf.nn.rnn_cell.LSTMCell(500)
output, state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    var = [x for x in tf.trainable_variables() if x.name=='RNN/LSTMCell/W_0:0']
    val = sess.run(var)
    # why 1524x2000?
    print(val[0].shape)

2 个答案:

答案 0 :(得分:1)

看起来您正在使用x作为输入调用您的单元格。权重矩阵的大小为(500 + 1024)x(4 * 500)。 LSTM具有四个门控功能,但出于效率原因,它们的各个矩阵连接在一起。这就是矩阵的第二个维度是4 * 500的原因。第一个维度是单元格的大小加上输入的大小,因为它需要与输入连接的输入与上一步的输出相乘。

答案 1 :(得分:0)

首先将预先训练的权重转换为张量,这可以通过阅读包含权重的文件并使用weights = tf.convert_to_tensor(your_weights)

来完成

然后找到你的lstm单元格权重的名称(这可以通过列出此图中的所有操作来完成)

最后使用tf.train.import_meta_graph(meta_graph_def=your_meta_graph_def, input_map={"lstm-cell-weights-name": weights})(最重要的是input_map参数)

仅适用于0.12之后的版本。

因为旧版本不支持input_map参数