MultiRNN不使用相同的BasicLSTM单元列表

时间:2018-06-01 13:52:22

标签: python tensorflow machine-learning lstm recurrent-neural-network

以下代码在使用相同的基本单元格(cell1, cell1) MultiRNNCell时失败:

import tensorflow as tf
cell1 = tf.contrib.rnn.BasicLSTMCell(128,reuse=False, name = "cell1")
cell2 = tf.contrib.rnn.BasicLSTMCell(128,reuse=False,name = "cell2")
multi = tf.contrib.rnn.MultiRNNCell([cell1, cell1] )
init = multi.zero_state(64, tf.float32)
output,state = multi(tf.ones([64,512]),init)

此代码与(cell1, cell2)一起使用的位置。但cell2cell1相同:

import tensorflow as tf
cell1 = tf.contrib.rnn.BasicLSTMCell(128,reuse=False, name = "cell1")
cell2 = tf.contrib.rnn.BasicLSTMCell(128,reuse=False,name = "cell2")
multi = tf.contrib.rnn.MultiRNNCell([cell1, cell2] )
init = multi.zero_state(64, tf.float32)
output,state = multi(tf.ones([64,512]),init)

我可以知道两个代码示例的区别吗?

错误是这样的:

  

ValueError:尺寸必须相等,但对于输入形状为'multi_rnn_cell / cell_0 / cell1 / MatMul_1'(op:'MatMul')的尺寸必须为256和640:[64,256],[640,512]。

1 个答案:

答案 0 :(得分:1)

这是一个已知的限制(例如讨论here)。问题是每个单元实例都为权重创建一个内部变量。此变量的维度由隐藏大小(在您的情况下为128)和此单元实例接收的输入大小(512)确定。当您多次使用相同的单元格时,必须确保输入在所有情况下都相同。

考虑您的示例代码:

import tensorflow as tf
cell1 = tf.contrib.rnn.BasicLSTMCell(128,reuse=False, name = "cell1")
cell2 = tf.contrib.rnn.BasicLSTMCell(128,reuse=False,name = "cell2")
multi = tf.contrib.rnn.MultiRNNCell([cell1, cell1] )
init = multi.zero_state(64, tf.float32)
output,state = multi(tf.ones([64,512]),init)

multi中两个单元格的输入将为[..., 640][..., 256],因为640=512+128(单元格接收来自前一个单元格以及来自前一单元格的输入输入序列)。因此,其中的权重矩阵将是 [640, 512] [256, 512] 512这里实际上是128*4,而不是输入大小)。

但是你正在使用相同的单元格实例! Tensorflow尝试将已有的矩阵与新输入匹配并失败。另一方面,当您使用不同的实例时,tensorflow能够为不同的层实例化不同的矩阵并正确地计算出形状。