如何加快tf.nn.softmax_cross_entropy_with_logits()中交叉熵损失的计算

时间:2017-07-20 20:37:45

标签: tensorflow

我想制作一个多标签分类模型(每个示例都有多个标签,每个示例的标签数量都不固定)。例如,example1可能有类标签" X"," Y&#34 ;,而example2有类标签" X"," Y"和" Z&#34 ;.我的目标是计算这种多标签分类模型的交叉熵损失。

我的第一个解决方案是手动创建目标类的密集单热表示并计算损失。但是,当我的词汇量大小为O(10K)时,这种解决方案很慢。我想知道是否有更有效的方法来做到这一点?

[更新以提供相关代码]

## During the data input phrase
def input_fn():
    ... 
    ## target_ids is a sparseTensor
    target_ids = lookup_table.lookup(target_label_strings)

    ## change the dense_shape 
    st2 = tf.SparseTensor(indices=target_ids.indices,
                          values=target_ids.values,
                          dense_shape=[batch_size,vocab_size])

    ## Convert to dense Tensor
    st2_ordered = tf.sparse_reorder(st2)
    dt = tf.sparse_tensor_to_dense(st2_ordered)

    ## Row normalization
    dt_float = tf.cast(dt, tf.float32)
    dt_float = tf.add(dt_float, tf.constant(1e-10))

    dt_row_norm = tf.reduce_sum(dt_float, axis=1)
    target["target_ids"] = dt_float / tf.reshape(dt_row_norm, (-1,1))

    return feature_map, target

## Model training
def get_loss_fn(self, target, weights, mode):
    ...
    ## the self.final_logit is the final output layer
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
      labels=target["target_ids"], logits=self.final_logit))
    ...

感谢。

1 个答案:

答案 0 :(得分:0)

在TensorFlow中执行softmax交叉熵时处理大词汇表的最简单方法是使用tf.nn.sampled_softmax_loss