我是tensorflow的新手(1天经验)。
我正在尝试使用小代码创建一个简单的基于GRU的RNN,其单层和隐藏大小为100,如下所示:
CanExecute
但我收到了最后一行的错误(即致电import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
# parameters
batch_size = 50
hidden_size = 100
# create network graph
input_data = tf.placeholder(tf.int32, [batch_size])
output_data = tf.placeholder(tf.int32, [batch_size])
cell = tf.nn.rnn_cell.GRUCell(hidden_size)
initial_state = cell.zero_state(batch_size, tf.float32)
hidden_state = initial_state
output_of_cell, hidden_state = cell(input_data, hidden_state)
)
cell()
我做错了什么?
答案 0 :(得分:0)
GRUCell
的呼叫运营商的输入应该是具有tf.float32
类型的二维张量。以下应该有效:
input_data = tf.placeholder(tf.float32, [batch_size, input_size])
cell = tf.nn.rnn_cell.GRUCell(hidden_size)
initial_state = cell.zero_state(batch_size, tf.float32)
hidden_state = initial_state
output_of_cell, hidden_state = cell(input_data, hidden_state)