在TensorFlow 2.0中使用GRUCell时总是出现错误

时间:2019-11-12 09:23:48

标签: python tensorflow2.0

当我将代码从TensorFlow 1.4迁移到TensorFlow 2.0时,GRUCell总是报告错误。

这是我在TensorFlow 1.4中的代码:

import tensorflow as tf

batch_size=10

depth=128

output_dim=100

inputs=tf.Variable(tf.random_normal([batch_size,depth]))

previous_state=tf.Variable(tf.random_normal([batch_size,output_dim])) 
gruCell=tf.nn.rnn_cell.GRUCell(output_dim)

output,state=gruCell(inputs,previous_state)

print(output)

print(state)

这是我在TensorFlow 2.0中的代码

import tensorflow as tf

batch_size = 10

depth = 128

output_dim = 100

inputs = tf.Variable(tf.random.normal([batch_size, depth]))

previous_state = tf.Variable(tf.random.normal([batch_size, output_dim])) 
gruCell = tf.keras.layers.GRUCell(output_dim)

output, state = gruCell(inputs, previous_state)

print(output)

print(state)

错误是:

  

文件“”,第3行,在raise_from中   tensorflow.python.framework.errors_impl.InvalidArgumentError:In [0]不是矩阵。相反,它的形状为[100] [Op:MatMul]名称:gru_cell / MatMul /

0 个答案:

没有答案