使用自定义softmax_loss函数时,tf.gather超出范围,即使它不应

时间:2019-03-13 16:37:27

标签: python tensorflow softmax seq2seq

我正在tf.contrib.seq2seq.sequence_loss(softmax_loss_function=[...])内使用一个小的自定义函数作为自定义sofmax_loss_function:

    def reduced_softmax_loss(self, labels, logits):
        top_logits, indices = tf.nn.top_k(logits, self.nb_top_classes, sorted=False)
        top_labels = tf.gather(labels, indices)

        return tf.nn.softmax_cross_entropy_with_logits_v2(labels=top_labels,
                                                          logits=top_logits)

但是,即使标签和日志的尺寸相同,执行后它也会返回InvalidArgumentError

indices[1500,1] = 2158 is not in [0, 1600),由于我的随机种子,数字有所不同。

还有像tf.gather这样的其他函数可以代替吗?还是返回的值是错误的形状?

如果我通过常规的Tensorflow函数,一切都可以正常工作。

提前谢谢!

1 个答案:

答案 0 :(得分:0)

仅通过查看代码就很难判断发生了什么,但是我认为您编写的代码并不能满足您的要求。 tf.gather操作需要一个索引输入,其中每个标量值都索引到第一个参数的最外层维度,但是在这里top_k的输出试图索引到行和列,这会导致超出范围的错误。