How do I initialize an LSTM cell in TensorFlow?

Date: 2017-02-27 19:01:59

Tags: tensorflow

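I am trying to run the following code in TF v1.0, and it throws an error. I create an LSTM cell and then define a state variable to pass to lstm_cell so that it can compute the output, but the state cannot be initialized: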

lstm_cell = tf.contrib.rnn.BasicLSTMCell(10)
# Initial state of the LSTM memory.
state = tf.zeros([20, lstm_cell.state_size])
outputs, states = lstm_cell(x , state)

Here is the traceback:

ValueError                                Traceback (most recent call last)
<ipython-input-82-4a23eee1acf4> in <module>()
      1 lstm_cell = tf.contrib.rnn.BasicLSTMCell(10)
      2 # Initial state of the LSTM memory.
----> 3 state = tf.zeros([20, lstm_cell.state_size])
      4 outputs, states = lstm_cell(x , state)

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.pyc in zeros(shape, dtype, name)
   1370       output = constant(zero, shape=shape, dtype=dtype, name=name)
   1371     except (TypeError, ValueError):
-> 1372       shape = ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")
   1373       output = fill(shape, constant(zero, dtype=dtype), name=name)
   1374   assert output.dtype.base_dtype == dtype

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in convert_to_tensor(value, dtype, name, preferred_dtype)
    649       name=name,
    650       preferred_dtype=preferred_dtype,
--> 651       as_ref=False)
    652 
    653 

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/ops.pyc in internal_convert_to_tensor(value, dtype, name, as_ref, preferred_dtype)
    714 
    715         if ret is None:
--> 716           ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    717 
    718         if ret is NotImplemented:

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.pyc in _constant_tensor_conversion_function(v, dtype, name, as_ref)
    174                                          as_ref=False):
    175   _ = as_ref
--> 176   return constant(v, dtype=dtype, name=name)
    177 
    178 

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.pyc in constant(value, dtype, shape, name, verify_shape)
    163   tensor_value = attr_value_pb2.AttrValue()
    164   tensor_value.tensor.CopyFrom(
--> 165       tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    166   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    167   const_tensor = g.create_op(

/Users/Saeed/anaconda/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.pyc in make_tensor_proto(values, dtype, shape, verify_shape)
    366     else:
    367       _AssertCompatible(values, dtype)
--> 368       nparray = np.array(values, dtype=np_dt)
    369       # check to them.
    370       # We need to pass in quantized values as tuples, so don't apply the shape

ValueError: setting an array element with a sequence.

1 answer:

Answer 0 (score: 0)

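This is because, in recent versions of TensorFlow, the state_size property of the default BasicLSTMCell returns an LSTMStateTuple (a Python tuple) rather than a single integer.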

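If you check the source code, you will see that both elements of the tuple report the same number of units (earlier versions concatenated them along a single axis), and this has to be taken into account when you initialize the cell's state.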
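To see why the call in the question fails, here is a minimal NumPy-only sketch of what the last traceback frame does with the shape argument. StateTuple and try_shape are hypothetical names introduced for illustration; the namedtuple only stands in for the real LSTMStateTuple, and the sizes 20 and 10 are the ones from the question.

```python
import collections

import numpy as np

# Hypothetical stand-in for tf.contrib.rnn.LSTMStateTuple (fields c and h).
StateTuple = collections.namedtuple("StateTuple", ("c", "h"))


def try_shape(shape):
    """Convert `shape` to an int32 array, as the last traceback frame does.

    Returns the array on success, or the ValueError instance on failure.
    """
    try:
        return np.array(shape, dtype=np.int32)
    except ValueError as err:
        return err


batch_size, num_units = 20, 10
state_size = StateTuple(c=num_units, h=num_units)

# Mirrors tf.zeros([20, lstm_cell.state_size]): the second entry is itself
# a tuple, so the list is ragged and cannot become a flat shape vector.
bad = try_shape([batch_size, state_size])
print(type(bad).__name__)

# One shape per tuple element is well-formed.
good = StateTuple(*(try_shape([batch_size, size]) for size in state_size))
print(good.c.tolist(), good.h.tolist())
```

The point of the sketch is only that a shape has to be a flat list of integers, so a tuple-valued state_size needs one shape per element.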

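So the initial state has to match that tuple: it should be an LSTMStateTuple holding one tensor of shape [batch_size, num_units] per element, not a single tensor whose shape contains state_size.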