tf.contrib.rnn.static_rnn - 'Tensor' 对象不可迭代

时间:2017-09-27 16:38:14

标签: tensorflow rnn

我正在尝试搭建一个简单的 RNN:

import tensorflow as tf

BATCH_SIZE = 7 
SEQUENCE_LENGTH = 5
VECTOR_SIZE = 3
STATE_SIZE = 4

x = tf.placeholder(tf.int32, [BATCH_SIZE, SEQUENCE_LENGTH, VECTOR_SIZE],
                   name='input_placeholder')
y = tf.placeholder(tf.int32, [BATCH_SIZE, SEQUENCE_LENGTH,],
                   name='labels_placeholder')
init_state = tf.zeros([BATCH_SIZE, STATE_SIZE])

rnn_inputs = tf.unstack(x, axis = 1)
cell = tf.contrib.rnn.BasicLSTMCell(STATE_SIZE, state_is_tuple = True)
rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=init_state)

它会出现以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-67-cf3e657dfb23> in <module>()
     13 cell = tf.contrib.rnn.BasicLSTMCell(STATE_SIZE, state_is_tuple = True)
     14 #rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=init_state)
---> 15 results__ = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=init_state)

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn.py in static_rnn(cell, inputs, initial_state, dtype, sequence_length, scope)
   1235             state_size=cell.state_size)
   1236       else:
-> 1237         (output, state) = call_cell()
   1238 
   1239       outputs.append(output)

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn.py in <lambda>()
   1222         varscope.reuse_variables()
   1223       # pylint: disable=cell-var-from-loop
-> 1224       call_cell = lambda: cell(input_, state)
   1225       # pylint: enable=cell-var-from-loop
   1226       if sequence_length is not None:

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py in __call__(self, inputs, state, scope)
    178       with vs.variable_scope(vs.get_variable_scope(),
    179                              custom_getter=self._rnn_get_variable):
--> 180         return super(RNNCell, self).__call__(inputs, state)
    181 
    182   def _rnn_get_variable(self, getter, *args, **kwargs):

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\layers\base.py in __call__(self, inputs, *args, **kwargs)
    448         # Check input assumptions set after layer building, e.g. input shape.
    449         self._assert_input_compatibility(inputs)
--> 450         outputs = self.call(inputs, *args, **kwargs)
    451 
    452         # Apply activity regularization.

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py in call(self, inputs, state)
    395     # Parameters of gates are concatenated into one multiply for efficiency.
    396     if self._state_is_tuple:
--> 397       c, h = state
    398     else:
    399       c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py in __iter__(self)
    474       TypeError: when invoked.
    475     """
--> 476     raise TypeError("'Tensor' object is not iterable.")
    477 
    478   def __bool__(self):

TypeError: 'Tensor' object is not iterable.

我不确定为什么会产生此错误。Stack Overflow 上有多个类似的帖子,但它们都没有解答这个问题。

更新

这是我在看到亚伦的回答之后所做的尝试:

BATCH_SIZE = 7 
SEQUENCE_LENGTH = 5
VECTOR_SIZE = 3
STATE_SIZE = 4

x = tf.placeholder(tf.int32, [BATCH_SIZE, SEQUENCE_LENGTH, VECTOR_SIZE],
                   name='input_placeholder')
y = tf.placeholder(tf.int32, [BATCH_SIZE, SEQUENCE_LENGTH,],
                   name='labels_placeholder')
init_state_state = tf.zeros([BATCH_SIZE, STATE_SIZE], tf.int32)
init_state_input = tf.zeros([BATCH_SIZE, VECTOR_SIZE], tf.int32)

rnn_inputs = tf.unstack(x, axis = 1)
cell = tf.contrib.rnn.BasicLSTMCell(STATE_SIZE, state_is_tuple = True)
rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=(init_state_state, init_state_input))

现在出现了以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-2-79fd4965953c> in <module>()
     14 rnn_inputs = tf.unstack(x, axis = 1)
     15 cell = tf.contrib.rnn.BasicLSTMCell(STATE_SIZE, state_is_tuple = True)
---> 16 rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=(init_state_state, init_state_input))

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn.py in static_rnn(cell, inputs, initial_state, dtype, sequence_length, scope)
   1235             state_size=cell.state_size)
   1236       else:
-> 1237         (output, state) = call_cell()
   1238 
   1239       outputs.append(output)

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn.py in <lambda>()
   1222         varscope.reuse_variables()
   1223       # pylint: disable=cell-var-from-loop
-> 1224       call_cell = lambda: cell(input_, state)
   1225       # pylint: enable=cell-var-from-loop
   1226       if sequence_length is not None:

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py in __call__(self, inputs, state, scope)
    178       with vs.variable_scope(vs.get_variable_scope(),
    179                              custom_getter=self._rnn_get_variable):
--> 180         return super(RNNCell, self).__call__(inputs, state)
    181 
    182   def _rnn_get_variable(self, getter, *args, **kwargs):

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\layers\base.py in __call__(self, inputs, *args, **kwargs)
    448         # Check input assumptions set after layer building, e.g. input shape.
    449         self._assert_input_compatibility(inputs)
--> 450         outputs = self.call(inputs, *args, **kwargs)
    451 
    452         # Apply activity regularization.

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py in call(self, inputs, state)
    399       c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
    400 
--> 401     concat = _linear([inputs, h], 4 * self._num_units, True)
    402 
    403     # i = input_gate, j = new_input, f = forget_gate, o = output_gate

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py in _linear(args, output_size, bias, bias_initializer, kernel_initializer)
   1051           _BIAS_VARIABLE_NAME, [output_size],
   1052           dtype=dtype,
-> 1053           initializer=bias_initializer)
   1054     return nn_ops.bias_add(res, biases)

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py in get_variable(name, shape, dtype, initializer, regularizer, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter)
   1063       collections=collections, caching_device=caching_device,
   1064       partitioner=partitioner, validate_shape=validate_shape,
-> 1065       use_resource=use_resource, custom_getter=custom_getter)
   1066 get_variable_or_local_docstring = (
   1067     """%s

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py in get_variable(self, var_store, name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter)
    960           collections=collections, caching_device=caching_device,
    961           partitioner=partitioner, validate_shape=validate_shape,
--> 962           use_resource=use_resource, custom_getter=custom_getter)
    963 
    964   def _get_partitioned_variable(self,

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py in get_variable(self, name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource, custom_getter)
    358           reuse=reuse, trainable=trainable, collections=collections,
    359           caching_device=caching_device, partitioner=partitioner,
--> 360           validate_shape=validate_shape, use_resource=use_resource)
    361     else:
    362       return _true_getter(

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\rnn_cell_impl.py in _rnn_get_variable(self, getter, *args, **kwargs)
    181 
    182   def _rnn_get_variable(self, getter, *args, **kwargs):
--> 183     variable = getter(*args, **kwargs)
    184     trainable = (variable in tf_variables.trainable_variables() or
    185                  (isinstance(variable, tf_variables.PartitionedVariable) and

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py in _true_getter(name, shape, dtype, initializer, regularizer, reuse, trainable, collections, caching_device, partitioner, validate_shape, use_resource)
    350           trainable=trainable, collections=collections,
    351           caching_device=caching_device, validate_shape=validate_shape,
--> 352           use_resource=use_resource)
    353 
    354     if custom_getter is not None:

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py in _get_single_variable(self, name, shape, dtype, initializer, regularizer, partition_info, reuse, trainable, collections, caching_device, validate_shape, use_resource)
    723           caching_device=caching_device,
    724           dtype=variable_dtype,
--> 725           validate_shape=validate_shape)
    726     self._vars[name] = v
    727     logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py in __init__(self, initial_value, trainable, collections, validate_shape, caching_device, name, variable_def, dtype, expected_shape, import_scope)
    197           name=name,
    198           dtype=dtype,
--> 199           expected_shape=expected_shape)
    200 
    201   def __repr__(self):

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variables.py in _init_from_args(self, initial_value, trainable, collections, validate_shape, caching_device, name, dtype, expected_shape)
    275             with ops.name_scope("Initializer"),  ops.device(None):
    276               self._initial_value = ops.convert_to_tensor(
--> 277                   initial_value(), name="initial_value", dtype=dtype)
    278               shape = (self._initial_value.get_shape()
    279                        if validate_shape else tensor_shape.unknown_shape())

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\variable_scope.py in <lambda>()
    699           initializer = initializer(dtype=dtype)
    700         init_val = lambda: initializer(  # pylint: disable=g-long-lambda
--> 701             shape.as_list(), dtype=dtype, partition_info=partition_info)
    702         variable_dtype = dtype.base_dtype
    703 

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\ops\init_ops.py in __call__(self, shape, dtype, partition_info, verify_shape)
    201       verify_shape = self._verify_shape
    202     return constant_op.constant(self.value, dtype=dtype, shape=shape,
--> 203                                 verify_shape=verify_shape)
    204 
    205   def get_config(self):

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\constant_op.py in constant(value, dtype, shape, name, verify_shape)
    100   tensor_value = attr_value_pb2.AttrValue()
    101   tensor_value.tensor.CopyFrom(
--> 102       tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape, verify_shape=verify_shape))
    103   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    104   const_tensor = g.create_op(

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_util.py in make_tensor_proto(values, dtype, shape, verify_shape)
    374       nparray = np.empty(shape, dtype=np_dt)
    375     else:
--> 376       _AssertCompatible(values, dtype)
    377       nparray = np.array(values, dtype=np_dt)
    378       # check to them.

~\AppData\Local\conda\conda\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_util.py in _AssertCompatible(values, dtype)
    300     else:
    301       raise TypeError("Expected %s, got %s of type '%s' instead." %
--> 302                       (dtype.name, repr(mismatch), type(mismatch).__name__))
    303 
    304 

TypeError: Expected int32, got 0.0 of type 'float' instead.

抱歉又更新了一次:在切换到 float32 之后,我遇到了另一个错误:

BATCH_SIZE = 7 
SEQUENCE_LENGTH = 5
VECTOR_SIZE = 3
STATE_SIZE = 4

x = tf.placeholder(tf.float32, [BATCH_SIZE, SEQUENCE_LENGTH, VECTOR_SIZE],
                   name='input_placeholder')
y = tf.placeholder(tf.float32, [BATCH_SIZE, SEQUENCE_LENGTH,],
                   name='labels_placeholder')
init_state_state = tf.zeros([BATCH_SIZE, STATE_SIZE], tf.float32)
init_state_input = tf.zeros([BATCH_SIZE, VECTOR_SIZE], tf.float32)

rnn_inputs = tf.unstack(x, axis = 1)
cell = tf.contrib.rnn.BasicLSTMCell(STATE_SIZE, state_is_tuple = True)
rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=(init_state_state, init_state_input))

以下是错误消息的最后一行:

ValueError: Trying to share variable rnn/basic_lstm_cell/kernel, but specified shape (7, 16) and found shape (6, 16).

1 个答案:

答案 0 :(得分:2)

这是因为 LSTM 的初始状态应该是一个元组,而你只传入了一个零向量。把初始状态改为由两个零向量组成的元组 (c, h) 就可以了;注意这两个向量的形状都应为 [BATCH_SIZE, STATE_SIZE]。

在简单的 RNN 或 GRU 中,只有一个状态向量;而 LSTM 的状态由两个向量组成(单元状态 c 和隐藏状态 h)。