为什么在 _dynamic_rnn_loop 中要将输入张量转换为 tensor_array_ops.TensorArray?

时间:2017-05-05 07:13:01

标签: tensorflow

在 TensorFlow dynamic_rnn 实现的 _dynamic_rnn_loop 中,为什么要将输入张量转换为 tensor_array_ops.TensorArray?另外,如果输入是二维数据,应该如何将其转换为二维的 TensorArray?

def _create_ta(name, dtype):
  """Return an empty TensorArray holding `time_steps` elements of `dtype`.

  The array's name is `base_name` with `name` appended, so each array
  created by this helper gets a distinct, prefixed name.
  """
  array_name = base_name + name
  return tensor_array_ops.TensorArray(
      dtype=dtype, size=time_steps, tensor_array_name=array_name)
# One output TensorArray per flattened output component. All of them use the
# dtype inferred from `dtype` / `state`.
output_ta = tuple(
    _create_ta("output_%d" % i, _infer_state_dtype(dtype, state))
    for i in range(len(flat_output_size)))

# One input TensorArray per flattened input component. Each array must use
# the dtype of its *own* component. Reusing flat_input[0].dtype for every
# array breaks nested inputs whose components have different dtypes: the
# unstack below would be handed a tensor whose dtype does not match the
# array's.
input_ta = tuple(
    _create_ta("input_%d" % i, input_.dtype)
    for i, input_ in enumerate(flat_input))

# Why a TensorArray: the while_loop body receives `time` only as a Tensor.
# Unstacking each input once up front lets `ta.read(time)` inside the loop
# fetch the slice for that step.
#
# TensorArray.unstack splits its argument along dimension 0, turning a
# rank-R tensor into elements of rank R-1. The arrays are sized by
# `time_steps`, and _time_step strips shape[0] after reading, so this
# assumes time-major inputs (TODO confirm at the caller). Under that
# assumption a 2-D [time_steps, depth] input becomes time_steps elements of
# shape [depth]; no special 2-D TensorArray is needed.
input_ta = tuple(ta.unstack(input_)
                 for ta, input_ in zip(input_ta, flat_input))

def _time_step(time, output_ta_t, state):
    """Take a time step of the dynamic RNN (used as the while_loop body).

    Reads this step's slice from every input TensorArray, runs the cell on it
    (through `_rnn_step` when `sequence_length` is given), and writes the
    flattened cell output into `output_ta_t` at index `time`.

    Args:
      time: int32 scalar Tensor; the index of the current step.
      output_ta_t: tuple of `TensorArray`s, one per flattened output
        component, that accumulate the output.
      state: nested tuple of tensors that represent the state.

    Returns:
      The tuple (time + 1, output_ta_t with updated flow, new_state).
    """

    # Fetch this step's slice from each (already unstacked) input array.
    input_t = tuple(ta.read(time) for ta in input_ta)
    # Restore static shape information lost by the TensorArray read.
    # Each entry of inputs_got_shape presumably includes the leading time
    # dimension, so that dimension is dropped here (TODO confirm).
    for input_, shape in zip(input_t, inputs_got_shape):
      input_.set_shape(shape[1:])

    # Rebuild the caller's nested input structure from the flat slices.
    input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t)
    call_cell = lambda: cell(input_t, state)

    if sequence_length is not None:
      # Delegate to _rnn_step, which receives the sequence-length bounds and
      # zero_output. It presumably masks output/state for sequences that have
      # already finished (NOTE(review): confirm against _rnn_step).
      (output, new_state) = _rnn_step(
          time=time,
          sequence_length=sequence_length,
          min_sequence_length=min_sequence_length,
          max_sequence_length=max_sequence_length,
          zero_output=zero_output,
          state=state,
          call_cell=call_cell,
          state_size=state_size,
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()

    # Flatten the (possibly nested) output so it lines up one-to-one with
    # the flat tuple of output TensorArrays.
    output = nest.flatten(output)

    # TensorArray.write returns a new TensorArray carrying the updated flow;
    # the returned arrays must be threaded through the loop variables.
    output_ta_t = tuple(
        ta.write(time, out) for ta, out in zip(output_ta_t, output))

    return (time + 1, output_ta_t, new_state)

def _keep_looping(time, *_):
  """Loop condition: keep stepping while `time` is below `time_steps`."""
  return time < time_steps

# Apply _time_step once per step. The final step counter is discarded; the
# filled output arrays and the last state are kept.
loop_result = control_flow_ops.while_loop(
    cond=_keep_looping,
    body=_time_step,
    loop_vars=(time, output_ta, state),
    parallel_iterations=parallel_iterations,
    swap_memory=swap_memory)
_, output_final_ta, final_state = loop_result

0 个答案:

没有答案