稀疏采样Softmax Tensorflow

时间:2017-10-09 07:07:12

标签: machine-learning tensorflow

您如何将tf.nn.sparse_softmax_cross_entropy_with_logits转换为使用采样softmax 而不是常规softmax?

我有一个序列来对具有大量目标词汇表(500K字)的模型进行排序,并且它会触发OOM错误。

softmax函数的输入如下:[batch, max_time_steps, 512]

1 个答案:

答案 0 :(得分:0)

我遇到了同样的问题,可以使用以下方法解决:

        labels = tf.reshape(labels, [-1, 1])
        loss = tf.nn.sampled_softmax_loss(
            weights=self.W_softmax,
            biases=self.b_softmax,
            labels=labels,
            inputs=logits,
            num_sampled=20,
            num_true=1,
            num_classes=20000,
            partition_strategy="div")

对我来说,关键是将num_sampled=20设置得很低,因为512太多而无法容纳我的GPU内存(8GB)。