TensorFlow MultiRNNCell

Date: 2017-09-25 00:19:59

Tags: tensorflow lstm

I have a question about MultiRNNCell, but first things first.

My data looks like this:

[[a1, a2, ..., a200], [b1, b2, ..., b200], ...]

The relevant code is here:

  rows, row_size = 20, 10
  num_classes = 3
  batch_size = 128
  hidden_layer_size = 256
  n_layers = 4

  tf_x = tf.placeholder(tf.float32, [None, rows, row_size])
  tf_y = tf.placeholder(tf.float32, [None, num_classes])

  in_x = tf.unstack(tf_x, axis=1)

  network = tf.contrib.rnn.MultiRNNCell([tf.contrib.rnn.BasicLSTMCell(hidden_layer_size, state_is_tuple=True)
                                       for _ in range(n_layers)], state_is_tuple=True)

  outputs, states = rnn.dynamic_rnn(cell=network, inputs=in_x, dtype=tf.float32)
  outputs = tf.matmul(outputs[-1], layer["weights"]) + layer["biases"]
  ...
  ...

  x_feed = np.array(x_feed.reshape((batch_size, rows, row_size)))
  _, c = sess.run([optimizer, loss_fn], feed_dict={tf_x: x_feed, tf_y: y_feed})

I get the error `ValueError: Shape (10, ?) must have rank at least 3`, and the traceback points to the line

  outputs, states = rnn.dynamic_rnn(cell=network, inputs=in_x, dtype=tf.float32)
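The shapes involved can be sketched with NumPy standing in for the TensorFlow tensors: `tf.unstack(tf_x, axis=1)` turns each rank-3 batch into a list of rank-2 time-step slices, which is the input format `static_rnn` expects but `dynamic_rnn` rejects with exactly this kind of rank error.

```python
import numpy as np

batch_size, rows, row_size = 128, 20, 10

# The placeholder tf_x carries rank-3 batches: (batch, rows, row_size).
x = np.zeros((batch_size, rows, row_size), dtype=np.float32)

# tf.unstack(tf_x, axis=1) is analogous to splitting along axis 1:
# it yields `rows` separate rank-2 arrays of shape (batch, row_size).
slices = [np.squeeze(s, axis=1) for s in np.split(x, rows, axis=1)]

print(len(slices))      # 20 time-step slices
print(slices[0].shape)  # (128, 10) -- rank 2, hence the rank error
```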

If I use static_rnn instead of dynamic_rnn, everything runs fine:

  outputs, states = rnn.static_rnn(cell=network, inputs=in_x, dtype=tf.float32)

But I don't understand what I am doing wrong. How can I use dynamic_rnn in this case?
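For reference, `dynamic_rnn` (with the default `time_major=False`) expects a single rank-3 tensor of shape `[batch, time, depth]`, whereas `static_rnn` expects the unstacked list of rank-2 tensors. So a likely fix is to skip the `unstack` and feed the placeholder directly. A minimal sketch against the TF 1.x API used in the question (not tested here); note that with `dynamic_rnn`, `outputs[-1]` would index the batch dimension, so the last time step is taken by slicing the time axis instead:

```python
# dynamic_rnn consumes the rank-3 placeholder directly: [batch, rows, row_size]
outputs, states = rnn.dynamic_rnn(cell=network, inputs=tf_x, dtype=tf.float32)

# outputs has shape [batch, rows, hidden_layer_size]; to mirror
# static_rnn's outputs[-1] (last time step), slice the time axis:
last_output = outputs[:, -1, :]
logits = tf.matmul(last_output, layer["weights"]) + layer["biases"]
```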

0 Answers:

No answers yet.