具有logits的张量流稀疏分类交叉熵

时间:2018-12-25 04:10:02

标签: python tensorflow

我是一名新手程序员,正在尝试遵循本指南: https://www.tensorflow.org/tutorials/sequences/text_generation 但是,我遇到了一个问题。该指南说将损失函数定义为:

def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

这给了我以下错误: “ sparse_categorical_crossentropy()获得了意外的关键字参数'from_logits'”

我的意思是“ from_logits”是函数中未指定的参数,该参数受文档支持,tf.keras.losses.sparse_categorical_crossentropy()仅具有2种可能的输入。有没有一种方法可以指定正在使用的日志,还是必须要登录?

2 个答案:

答案 0 :(得分:5)

在学习本教程时,我遇到了同样的问题。我从

更改了代码
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

def loss(labels, logits):
    return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

这解决了该问题,而无需每晚安装tf。

答案 1 :(得分:1)

from_logits参数在Tensorflow 1.13中引入。

您可以将1.12和1.13与以下网址进行比较:

https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/losses.py
https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/losses.py
在撰写本文时,

1.13尚未发布。这就是本教程以这一行开头的原因

!pip install -q tf-nightly