When I migrate my code from TensorFlow 1.4 to TensorFlow 2.0, GRUCell always raises an error.
This is my code in TensorFlow 1.4:
import tensorflow as tf
batch_size=10
depth=128
output_dim=100
inputs=tf.Variable(tf.random_normal([batch_size,depth]))
previous_state=tf.Variable(tf.random_normal([batch_size,output_dim]))
gruCell=tf.nn.rnn_cell.GRUCell(output_dim)
output,state=gruCell(inputs,previous_state)
print(output)
print(state)
This is my code in TensorFlow 2.0:
import tensorflow as tf
batch_size = 10
depth = 128
output_dim = 100
inputs = tf.Variable(tf.random.normal([batch_size, depth]))
previous_state = tf.Variable(tf.random.normal([batch_size, output_dim]))
gruCell = tf.keras.layers.GRUCell(output_dim)
output, state = gruCell(inputs, previous_state)
print(output)
print(state)
The error is:
File "", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: In[0] is not a matrix. Instead it has shape [100] [Op:MatMul] name: gru_cell/MatMul/
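
One possible reading of this error, which is an assumption and not something confirmed in the post: the Keras cell may treat its second argument as a sequence of state tensors and read its first element. A bare tensor of shape [10, 100] would then be indexed down to shape [100], which matches the message. The sketch below wraps the state in a list. It uses plain tensors instead of tf.Variable and renames gruCell to gru_cell, both only for illustration.

```python
import tensorflow as tf

batch_size = 10
depth = 128
output_dim = 100

# Same shapes as in the question; plain tensors are enough for this check.
inputs = tf.random.normal([batch_size, depth])
previous_state = tf.random.normal([batch_size, output_dim])

gru_cell = tf.keras.layers.GRUCell(output_dim)

# Assumption: the Keras cell reads states[0]. Indexing a bare
# [batch_size, output_dim] tensor that way would give a [output_dim] slice,
# consistent with "shape [100]" in the error. A one-element list avoids that.
output, state = gru_cell(inputs, [previous_state])

print(output.shape)    # shape of the cell output
print(state[0].shape)  # the new state comes back as a one-element list
```

If this reading is right, note that the returned state is also a list, so code that previously used state directly would need state[0].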