拟合Tensorflow2模型时形状不兼容的精度

时间:2019-04-27 22:16:06

标签: python numpy tensorflow keras tensorflow2.0

我正在Tensorfow 2.0.0-alpha0上运行文本生成模型(RNN),尽管在拟合模型时获得了损失指标,但在插入精度时却遇到了以下错误:

  

InvalidArgumentError:形状不兼容:[64]与[64,200]
  [[{{nodemetrics_4 / accuracy / Equal}}]]   [Op:__ inference_keras_scratch_graph_6491]

我尝试手动定义单个批次的准确性(预训练):

def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
def accuracy(labels, logits):
    return tf.keras.metrics.sparse_categorical_accuracy(labels,l ogits)

example_batch_loss  = loss(target_example_batch, example_batch_predictions)
example_batch_acc  = accuracy(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Loss:      ", example_batch_loss.numpy().mean())
print("Accuracy:      ", example_batch_acc.numpy().mean())

输出为:

  

预测形状:(64、200、34)#(批处理大小,序列长度,vocab_size)   损失:3.5263805   准确性:0.01265625

然后我跟随:

optimizer = tf.keras.optimizers.RMSprop(lr=lr) 
model.compile(optimizer=optimizer, loss=loss, metrics =['accuracy']) 
history = model.fit(dataset, epochs=epochs, callbacks[checkpoint_callback]) 

并收到上述错误(丢失正常)。如果我在编译中尝试“精度=准确性”,则会得到:

  

提高ValueError('在此期间不支持会话关键字参数   渴望执行。您通过了:%s%(kwargs,))

有什么想法/建议吗?

1 个答案:

答案 0 :(得分:0)

accuracy不是Model.fit的标准参数-将在**kwargs下接受,然后以图形方式传递给session.run。尝试metrics=[accuracy]