在Keras中自定义损失,其中softmax为一热

时间:2018-06-26 10:44:09

标签: tensorflow keras nlp loss-function

我有一个输出Softmax的模型,我想开发一个自定义损失函数。所需的行为将是:

1)将Softmax设置为一热(通常我执行numpy.argmax(softmax_vector)并将空向量中的索引设置为1,但这在损失函数中是不允许的。)

2)将所得的一热点向量乘以我的嵌入矩阵,以获得一个嵌入向量(在我的上下文中:与给定单词相关联的单词向量,其中单词已被标记并分配给索引或类用于Softmax输出)。

3)将该向量与目标进行比较(这可能是正常的Keras损失函数)。

我知道一般如何编写自定义损失函数,但不这样做。我发现了这个closely related question(未答复),但是我的情况有些不同,因为我想保留我的softmax输出。

2 个答案:

答案 0 :(得分:1)

可以在您的客户损失函数中混合tensorflow和keras。一旦您可以访问所有Tensorflow功能,事情就会变得非常容易。我只是给你一个例子,说明这个功能的实现方式。

import tensorflow as tf
def custom_loss(target, softmax):
    max_indices = tf.argmax(softmax, -1)

    # Get the embedding matrix. In Tensorflow, this can be directly done
    # with tf.nn.embedding_lookup
    embedding_vectors = tf.nn.embedding_lookup(you_embedding_matrix, max_indices)

    # Do anything you want with normal keras loss function
    loss = some_keras_loss_function(target, embedding_vectors)

    loss = tf.reduce_mean(loss)
    return loss

答案 1 :(得分:0)

Fan Luo的答案指向了正确的方向,但最终因为它涉及到不可导数运算而无法正常工作。请注意,此类操作对于实际值是可接受的(损失函数采用实际值和预测值,不可导数运算仅适用于实际值)。

说句公道话,那是我首先要问的。 不可能做我想做的事,但我们可以得到相似且可衍生的行为:

1)softmax值的逐元素幂。这使较小的值小得多。例如,幂为4的 [0.5,0.2,0.7] 变为 [0.0625,0.0016,0.2400] 。请注意,0.2与0.7相当,但相对于0.24,0.0016可以忽略不计。 my_power越高,最终结果将更接近一击。

soft_extreme = Lambda(lambda x: x ** my_power)(softmax)

2)重要的是,softmax和一热向量都已标准化,但我们的“ soft_extreme”未标准化。首先,找到数组的总和:

norm = tf.reduce_sum(soft_extreme, 1)

3)标准化soft_extreme:

almost_one_hot = Lambda(lambda x: x / norm)(soft_extreme)

注意:在1)中将my_power设置得过高会导致NaN。如果您需要更好的softmax到一键式转换,则可以连续两次或多次执行步骤1到3。

4)最后,我们需要字典中的向量。禁止查找,但是我们可以使用矩阵乘法来取平均向量。因为我们的soft_normalized类似于一键编码,所以该平均值将类似于与最高自变量(原始预期行为)相关的向量。 (1)中的my_power越高,则它的真实性越高:

target_vectors = tf.tensordot(almost_one_hot, embedding_matrix, axes=[[1], [0]])

注意:这不能直接使用批处理!以我为例,我使用 tf.reshape 将我的“一个热门”(从 [batch,dictionary_length] 更改为 [batch,1,dictionary_length] )。 >。然后平铺我的embedding_matrix批处理时间,并最终使用:

predicted_vectors = tf.matmul(reshaped_one_hot, tiled_embedding)

可能会有更优雅的解决方案(或者如果不希望平铺嵌入矩阵,则可能需要更少的内存),所以请随时进行探索。