Tensorflow concat / split problem in a recurrent network example

Asked: 2016-06-03 10:06:16

Tags: python-2.7 tensorflow recurrent-neural-network

Consider the following example code:

import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 20
BATCH_SIZE = 3


def iterate_state(prev_state_tuple, input):
    with tf.name_scope('h1'):
        weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
        matmuladd = tf.matmul(inputs, weights) + biases
        print("prev state: ",prev_state_tuple.get_shape())
        unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
        prev_state = unpacked_state
        state = 0.9* prev_state + 0.1*matmuladd
        output = tf.nn.relu(state)
        print(" state: ", state.get_shape())
        print(" output: ", output.get_shape())
        concat_result = tf.concat(0,[state, output])
        print (" concat return: ", concat_result.get_shape())
        return concat_result

def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs

with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
    with tf.variable_scope('states'):
        initial_state = tf.zeros([HIDDEN_1],
                                 name='initial_state')
        initial_out = tf.zeros([HIDDEN_1],
                                 name='initial_out')
        concat_tensor = tf.concat(0,[initial_state, initial_out])
        print(" init state: ",initial_state.get_shape())
        print(" init out: ",initial_out.get_shape())
        print(" concat: ",concat_tensor.get_shape())
        scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
        print ("scanout shape: ", scanout.get_shape())
        state, output = tf.split(0,2,scanout, name='split_scan_output')

    sess = tf.Session()
    # Run the Op to initialize the variables.
    sess.run(tf.initialize_all_variables())
    iter_ = data_iter()
    for i in xrange(0, 2):
        print ("iteration: ",i)
        input_data = iter_.next()
        out,st = sess.run([output,state], feed_dict={ inputs: input_data})

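I am trying to concatenate the internal state and the output into a single tensor, and split them apart again, so that the step function fits the tf.scan interface.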

However, when I run this example I get the following error:

(' init state: ', TensorShape([Dimension(20)]))
(' init out: ', TensorShape([Dimension(20)]))
(' concat: ', TensorShape([Dimension(40)]))
('prev state: ', TensorShape([Dimension(40)]))
(' state: ', TensorShape([Dimension(3), Dimension(20)]))
(' output: ', TensorShape([Dimension(3), Dimension(20)]))
(' concat return: ', TensorShape([Dimension(6), Dimension(20)]))
('scanout shape: ', TensorShape(None))
('iteration: ', 0)
Traceback (most recent call last):
  File "cycles_in_graphs_with_scan.py", line 57, in <module>
    out,st = sess.run([output,state], feed_dict={ inputs: input_data})
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 340, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 564, in _run
    feed_dict_string, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 637, in _do_run
    target_list, options, run_metadata)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 659, in _do_call
    e.code)
tensorflow.python.framework.errors.InvalidArgumentError: Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 3) and num_split 2
     [[Node: states/split_scan_output = Split[T=DT_FLOAT, num_split=2, _device="/job:localhost/replica:0/task:0/cpu:0"](states/split_scan_output/split_dim, states/state_scan/TensorArrayPack)]]
Caused by op u'states/split_scan_output', defined at:
  File "cycles_in_graphs_with_scan.py", line 46, in <module>
    state, output = tf.split(0,2,scanout, name='split_scan_output')
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 525, in split
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1428, in _split
    num_split=num_split, name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 655, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2154, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1154, in __init__
    self._traceback = _extract_stack()

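Although the tensor returned by the step function clearly has shape (6, 20), the shape returned by tf.scan appears to be None, and the error says it found a dimension of size 3.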

Any idea what might be causing this error?

1 answer:

Answer 0 (score: 0)

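It looks like tf.scan() cannot infer a static shape for its output, so the problem is only caught at run time: the failure happens when you try to split scanout into 2 tensors along dimension 0.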
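The run-time check behind this error can be sketched in plain Python. The helper `check_even_split` below is hypothetical and is not part of TensorFlow; it only mirrors the rule quoted in the error message, namely that the number of splits must evenly divide the size of the split dimension.

```python
def check_even_split(dim_size, num_split):
    """Return the size of each piece, or raise if the split is uneven.

    Hypothetical helper: it mirrors the rule quoted in the error message,
    not TensorFlow's actual implementation.
    """
    if dim_size % num_split != 0:
        raise ValueError(
            "cannot split a dimension of size %d into %d equal parts"
            % (dim_size, num_split))
    return dim_size // num_split


# Dimension 0 of scanout has size 3 at run time (per the error message),
# so asking for 2 equal pieces fails.
try:
    check_even_split(3, 2)
    uneven_split_failed = False
except ValueError:
    uneven_split_failed = True

# A dimension of size 6 (the per-step concat result) splits cleanly in 2.
piece_size = check_even_split(6, 2)
```

With a static shape of None the graph builder has no size to check, which is presumably why the mismatch only shows up inside sess.run().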

In this case, the best thing to do is to evaluate scanout and look at its actual shape:

sess = tf.Session()
sess.run(tf.initialize_all_variables())
iter_ = data_iter()
input_data = iter_.next()
scanout_val = sess.run(scanout, feed_dict={inputs: input_data})

print("Actual shape of scanout:", scanout_val.shape)
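As a rough guide to what that shape should look like, here is a plain-Python stand-in for a scan. It is only a sketch of the general idea (one call per element along dimension 0, with the per-step results stacked), not TensorFlow's implementation; `scan_sketch` and `step` are made-up names for illustration.

```python
def scan_sketch(fn, elems, initializer):
    """Plain-Python stand-in for a scan.

    Calls fn once per element of elems (iterating over dimension 0),
    threads the accumulator through, and collects every step's result.
    """
    results = []
    acc = initializer
    for elem in elems:
        acc = fn(acc, elem)
        results.append(acc)
    return results


# 3 rows of 10 values each, shaped like a (BATCH_SIZE, INPUTS) batch.
batch = [[0.0] * 10 for _ in range(3)]


def step(prev, row):
    # Hypothetical step function: running total of the row sums.
    return prev + sum(row)


stacked = scan_sketch(step, batch, 0.0)

# One result per element along dimension 0 of the input.
num_results = len(stacked)
```

Because the placeholder has shape (BATCH_SIZE, INPUTS), a scan over it takes BATCH_SIZE steps, so the stacked result has BATCH_SIZE entries along dimension 0.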

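Judging from the error message, scanout has size 3 in dimension 0. I suspect this comes from the batch size, since the input and the output of tf.scan() have the same size in dimension 0. One possibility is that you actually want to split along dimension 1: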

state, output = tf.split(1, 2, scanout, name='split_scan_output')
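To see what a split along dimension 1 would produce, here is a plain-Python sketch on nested lists. It assumes that scanout has shape (3, 6, 20) at run time, which is what the printed per-step shape (6, 20) and the size 3 in the error message suggest, but which should be confirmed by evaluating scanout. `shape_of` and `split_dim1` are hypothetical helpers, not TensorFlow functions.

```python
def shape_of(nested):
    """Return the shape of a regular (non-ragged) nested list as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


def split_dim1(nested, num_split):
    """Split a nested list into num_split equal pieces along dimension 1."""
    size = len(nested[0])
    if size % num_split != 0:
        raise ValueError(
            "dimension 1 of size %d is not divisible by %d" % (size, num_split))
    chunk = size // num_split
    return [[row[i * chunk:(i + 1) * chunk] for row in nested]
            for i in range(num_split)]


# Assumed run-time shape of scanout: 3 steps, each a (6, 20) concat result.
scanout_like = [[[0.0] * 20 for _ in range(6)] for _ in range(3)]

# Dimension 1 has size 6, so it divides evenly into two pieces of size 3.
state_like, output_like = split_dim1(scanout_like, 2)
```

Under that assumption each piece has shape (3, 3, 20). Since the step function concatenates [state, output] along dimension 0, the first piece would hold the states and the second the outputs.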