在tf中使用rnn_cell,同时获取ValueError:两个结构没有相同数量的元素

时间:2019-03-22 04:46:39

标签: python tensorflow recurrent-neural-network

给出data = tf.placeholder(tf.float32, [2, None, 3])(批处理大小*时间步长*特征大小),理想情况下,我想做tf.unstack(data, axis = 1)来获取许多张量,每个张量的形状为[2,3],以便稍后将它们馈送到rn带有

这样的for循环
for rnn_input in rnn_inputs:
    state = rnn_cell(rnn_input, state)

使用tf.nn.dynamic_rnn之类的高级API不在桌面上,所以我创建了

之类的解决方法
import tensorflow as tf


data = tf.placeholder(tf.float32, [2, None, 3])

step_number = tf.placeholder(tf.int32, None)

loop_counter_inital = tf.constant(0)

initi_state = tf.zeros([2,3], tf.float32)

def while_condition(loop_counter, rnn_states):
    return loop_counter < step_number

def while_body(loop_counter, rnn_states):
    loop_counter_current = loop_counter

    current_states = tf.gather_nd(data, tf.stack([tf.range(0, 2), tf.zeros([2], tf.int32)+loop_counter_current], axis=1))     

    cell = tf.nn.rnn_cell.BasicRNNCell(3)

    rnn_states = cell(current_states, rnn_states)

    return [loop_counter_current, rnn_states]


_, _states = tf.while_loop(while_condition, while_body, 
                   loop_vars=[loop_counter_inital, initi_state], 
                   shape_invariants=[loop_counter_inital.shape, tf.TensorShape([2, 3])])

with tf.Session() as sess:    

    sess.run(tf.global_variables_initializer())

    print (sess.run(_states, feed_dict={data:[[[3,1,6],[4,1,2]],[[5,8,1],[0,5,2]]], step_number:2 }))

这个想法是循环遍历data的每个2D张量中的每一行,以获得每个时间步的特征。我遇到错误

First structure (2 elements): [<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, <tf.Tensor 'while/Identity_1:0' shape=(2, 3) dtype=float32>]

Second structure (3 elements): [<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, (<tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>, <tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>)]

似乎有一些相关的帖子。没有实际的工作。有人可以帮忙吗?

1 个答案:

答案 0 :(得分:0)

您需要知道每个BasicRNNCell将实现具有签名call()的{​​{1}}。这意味着您的结果是形状(output, next_state) = call(input, state)的列表。因此,您需要执行以下操作。

((?,unit),(?,unit))

您在这里也犯了一个错误。您忘记在rnn_states = cell(current_states, rnn_states)[1] 上加1。

loop_counter_current

添加

第一个结构代表您传入的参数return [loop_counter_current+1, rnn_states] 的初始值,其中包含loop_varsloop_counter_inital的初始值。因此其结构对应于以下内容。

initi_state

第二个结构表示循环后的参数[ <tf.Tensor 'while/Identity:0' shape=() dtype=int32> #---> loop_counter_inital , <tf.Tensor 'while/Identity_1:0' shape=(2, 3) dtype=float32> #---> initi_state ] 。根据先前的错误,其结果与以下内容相对应。

loop_vars