How do I extract all the weights from an LSTM cell in vanilla TensorFlow?

Asked: 2017-06-21 13:48:45

Tags: tensorflow

I am training an LSTM network:

cell_fw = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE)
cell_bw = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE)

rnn_outputs, final_state_fw, final_state_bw = tf.contrib.rnn.static_bidirectional_rnn(
    cell_fw=cell_fw,
    cell_bw=cell_bw,
    inputs=rnn_inputs,
    dtype=tf.float32
)

After training, I try to save the weights:

d = {}
with tf.Session() as sess:
    # train code ...
    # Fetch every global variable and key its value by the variable's name.
    variables = tf.global_variables()
    values = sess.run(variables)
    for var, value in zip(variables, values):
        d[var.name] = value

The dictionary d contains only 2 objects for each LSTM cell:

[(k,v.shape) for (k,v) in sorted(d.items(), key=lambda x:x[0])]
[('bidirectional_rnn/bw/basic_lstm_cell/biases:0', (1024,)),
 ('bidirectional_rnn/bw/basic_lstm_cell/weights:0', (272, 1024)),
 ('bidirectional_rnn/fw/basic_lstm_cell/biases:0', (1024,)),
 ('bidirectional_rnn/fw/basic_lstm_cell/weights:0', (272, 1024)),
 ('char_embedding:0', (70, 16)),
 ('softmax_biases:0', (5068,)),
 ('softmax_weights:0', (5068, 512))]

I am confused. Shouldn't each LSTM cell contain 4 trainable layers? If so, how do I get all of the weights out of the LSTM cell?

1 Answer:

Answer 0: (score: 1)

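The 4 sets of weights (and biases) of an LSTM cell are stored as a single tensor, where slices along the second axis correspond to the different kinds of weights (input gate, forget gate, etc.).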

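For example, I guess that in your case the value of HIDDEN_SIZE is 256, since 1024 = 4 × 256. The 272 rows would then presumably be the input size (16, the width of char_embedding) plus the hidden size (256).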

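To access the different parts, you should slice the tensor along the axis of length 1024 (but I don't know the order in which the different kinds of weights are stored...)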
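
The slicing described above can be sketched with NumPy on the arrays already saved in d. This is only an illustration under stated assumptions: HIDDEN_SIZE = 256 and an input width of 16 are inferred from the printed shapes, the helper split_lstm_params is a hypothetical name, and the gate order (i, j, f, o: input gate, new input, forget gate, output gate) and the row layout (input rows first, then hidden-state rows) are taken from how the TF 1.x BasicLSTMCell source concatenates and splits its tensors, so they should be checked against the installed version. Random-free stand-in arrays are used here so the snippet runs without TensorFlow.

```python
import numpy as np

# Assumed sizes, inferred from the shapes printed in the question:
# 1024 = 4 * HIDDEN_SIZE and 272 = INPUT_SIZE + HIDDEN_SIZE.
HIDDEN_SIZE = 256
INPUT_SIZE = 16

# Assumed gate order along the 4 * HIDDEN_SIZE axis (the TF 1.x
# BasicLSTMCell source splits as i, j, f, o); verify for your version.
GATE_ORDER = ("input_gate", "new_input", "forget_gate", "output_gate")


def split_lstm_params(weights, biases, input_size):
    """Split a fused LSTM weight matrix and bias vector into per-gate pieces.

    Returns {gate_name: (W_x, W_h, b)}, where W_x multiplies the input,
    W_h multiplies the previous hidden state, and b is the gate's bias.
    """
    n_gates = len(GATE_ORDER)
    if weights.shape[1] % n_gates != 0:
        raise ValueError("second axis is not divisible by the number of gates")
    # Four equal slices along the second axis, one per gate.
    w_parts = np.split(weights, n_gates, axis=1)
    b_parts = np.split(biases, n_gates)
    params = {}
    for name, w, b in zip(GATE_ORDER, w_parts, b_parts):
        # Assumed row layout: input rows first, then hidden-state rows.
        params[name] = (w[:input_size], w[input_size:], b)
    return params


# Stand-ins for d['bidirectional_rnn/fw/basic_lstm_cell/weights:0'] and
# d['bidirectional_rnn/fw/basic_lstm_cell/biases:0'].
weights = np.arange(
    (INPUT_SIZE + HIDDEN_SIZE) * 4 * HIDDEN_SIZE, dtype=np.float32
).reshape(INPUT_SIZE + HIDDEN_SIZE, 4 * HIDDEN_SIZE)
biases = np.zeros(4 * HIDDEN_SIZE, dtype=np.float32)

fw_params = split_lstm_params(weights, biases, INPUT_SIZE)
for name, (w_x, w_h, b) in fw_params.items():
    print(name, w_x.shape, w_h.shape, b.shape)
```

With the real values, replace the stand-in arrays with the corresponding entries of d; the same call then works for the bw cell.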