张量流(python)的GRUCell中输入和隐藏状态的大小应该是多少?

时间:2016-07-06 18:44:20

标签: python tensorflow recurrent-neural-network gated-recurrent-unit

我是tensorflow的新手(1天经验)。

我正在尝试使用小代码创建一个简单的基于GRU的RNN,其单层和隐藏大小为100,如下所示:

CanExecute

但我收到了最后一行的错误(即致电import pickle import numpy as np import pandas as pd import tensorflow as tf # parameters batch_size = 50 hidden_size = 100 # create network graph input_data = tf.placeholder(tf.int32, [batch_size]) output_data = tf.placeholder(tf.int32, [batch_size]) cell = tf.nn.rnn_cell.GRUCell(hidden_size) initial_state = cell.zero_state(batch_size, tf.float32) hidden_state = initial_state output_of_cell, hidden_state = cell(input_data, hidden_state)

cell()

我做错了什么?

1 个答案:

答案 0 :(得分:0)

GRUCell的呼叫运营商的输入应该是具有tf.float32类型的二维张量。以下应该有效:

input_data = tf.placeholder(tf.float32, [batch_size, input_size])

cell = tf.nn.rnn_cell.GRUCell(hidden_size)

initial_state = cell.zero_state(batch_size, tf.float32)

hidden_state = initial_state

output_of_cell, hidden_state = cell(input_data, hidden_state)