请考虑以下示例代码:
import tensorflow as tf
import math
import numpy as np
INPUTS = 10
HIDDEN_1 = 20
BATCH_SIZE = 3
def iterate_state(prev_state_tuple, input):
with tf.name_scope('h1'):
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
matmuladd = tf.matmul(inputs, weights) + biases
print("prev state: ",prev_state_tuple.get_shape())
unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
prev_state = unpacked_state
state = 0.9* prev_state + 0.1*matmuladd
output = tf.nn.relu(state)
print(" state: ", state.get_shape())
print(" output: ", output.get_shape())
concat_result = tf.concat(0,[state, output])
print (" concat return: ", concat_result.get_shape())
return concat_result
def data_iter():
while True:
idxs = np.random.rand(BATCH_SIZE, INPUTS)
yield idxs
with tf.Graph().as_default():
inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
with tf.variable_scope('states'):
initial_state = tf.zeros([HIDDEN_1],
name='initial_state')
initial_out = tf.zeros([HIDDEN_1],
name='initial_out')
concat_tensor = tf.concat(0,[initial_state, initial_out])
print(" init state: ",initial_state.get_shape())
print(" init out: ",initial_out.get_shape())
print(" concat: ",concat_tensor.get_shape())
scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
print ("scanout shape: ", scanout.get_shape())
state, output = tf.split(0,2,scanout, name='split_scan_output')
sess = tf.Session()
# Run the Op to initialize the variables.
sess.run(tf.initialize_all_variables())
iter_ = data_iter()
for i in xrange(0, 2):
print ("iteration: ",i)
input_data = iter_.next()
out,st = sess.run([output,state], feed_dict={ inputs: input_data})
我正在尝试连接并拆分内部状态并将张量输出到一起,以便它可以符合tf.scan
接口。
但是,在运行此示例时,我收到此错误:
(' init state: ', TensorShape([Dimension(20)]))
(' init out: ', TensorShape([Dimension(20)]))
(' concat: ', TensorShape([Dimension(40)]))
('prev state: ', TensorShape([Dimension(40)]))
(' state: ', TensorShape([Dimension(3), Dimension(20)]))
(' output: ', TensorShape([Dimension(3), Dimension(20)]))
(' concat return: ', TensorShape([Dimension(6), Dimension(20)]))
('scanout shape: ', TensorShape(None))
('iteration: ', 0)
Traceback (most recent call last):
File "cycles_in_graphs_with_scan.py", line 57, in <module>
out,st = sess.run([output,state], feed_dict={ inputs: input_data})
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 340, in run
run_metadata_ptr)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 564, in _run
feed_dict_string, options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 637, in _do_run
target_list, options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 659, in _do_call
e.code)
tensorflow.python.framework.errors.InvalidArgumentError: Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 3) and num_split 2
[[Node: states/split_scan_output = Split[T=DT_FLOAT, num_split=2, _device="/job:localhost/replica:0/task:0/cpu:0"](states/split_scan_output/split_dim, states/state_scan/TensorArrayPack)]]
Caused by op u'states/split_scan_output', defined at:
File "cycles_in_graphs_with_scan.py", line 46, in <module>
state, output = tf.split(0,2,scanout, name='split_scan_output')
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 525, in split
name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1428, in _split
num_split=num_split, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 655, in apply_op
op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2154, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1154, in __init__
self._traceback = _extract_stack()
虽然返回张量的大小明显为(6,20)
,但tf.scan
的返回形状似乎为None
,而错误则表示它找到了一个长度为{ 3
知道可能导致此错误的原因是什么?
答案 0 :(得分:0)
看起来tf.scan()
函数无法推断输出的静态形状,因此当您尝试将scanout
拆分为2个张量时,您会遇到运行时失败第0维。
在这种情况下,最好的办法是评估scanout
以查看其实际形状:
sess = tf.Session()
sess.run(tf.initialize_all_variables())
iter_ = data_iter()
input_data = iter_.next()
scanout_val = sess.run(scanout, feed_dict={inputs: input_data})
print("Actual shape of scanout:", scanout_val.shape)
从错误消息看,它在第0维中的大小为3,我怀疑它来自批量大小,因为tf.scan()
的输入和输出的第0维将具有相同的尺寸。一种可能性是你实际上想要拆分第一维:
state, output = tf.split(1, 2, scanout, name='split_scan_output')