My simple loss function causes NaN

Time: 2018-08-08 03:36:58

Tags: python-3.x tensorflow nan loss-function

I have written a custom loss function by myself, but after a few steps the loss becomes nan. My code is:

import tensorflow as tf

def my_loss(label_batch, logits_batch, alpha=1.3, beta=0.5):
    softmax_logits_batch = tf.nn.softmax(logits_batch, axis=-1)

    indices_not_0 = tf.where(tf.not_equal(label_batch, 0))  # indices of non-zero labels
    indices_0 = tf.where(tf.equal(label_batch, 0))           # indices of zero labels

    predict_not_0 = tf.gather_nd(softmax_logits_batch, indices_not_0)
    predict_0 = tf.gather_nd(softmax_logits_batch, indices_0)
    avg_p_not_0 = tf.reduce_mean(predict_not_0, axis=0)  # class-1 center, shape (2,)
    avg_p_0 = tf.reduce_mean(predict_0, axis=0)          # class-0 center, shape (2,)

    euclidean_distance = tf.sqrt(tf.reduce_sum(tf.square(avg_p_0 - avg_p_not_0)))
    max_value = tf.maximum(alpha - euclidean_distance, 0)
    return max_value

Some of the basic ideas behind it:

  1. My loss is for semantic segmentation with only 2 classes.

  2. label_batch has shape (?, H, W), where every value is 0 or 1. logits_batch has shape (?, H, W, 2); its values are the logits of an FCN (without softmax).

  3. I want to gather all the logits values (predict_0 and predict_not_0) whose label value is 0 or 1 respectively, via indices_0 or indices_not_0.

  4. predict_not_0 and predict_0 should both have shape (?, 2) (see the shape sketch after this list).

  5. Compute the mean of predict_not_0 and predict_0 respectively (they represent the center point coordinates of class 1 and class 0 in Euclidean space). Their shapes should be (2,).

  6. Compute the Euclidean distance between the two center point coordinates; it should be larger than some specific alpha value (e.g. alpha = 1.3).
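
To make the shapes in steps 3-5 concrete, here is a tiny self-contained check (a sketch only, assuming TensorFlow 1.x graph mode; the toy values are made up):

import numpy as np
import tensorflow as tf

# Toy batch: one 2x2 image, binary labels, random logits.
labels = np.array([[[0, 1],
                    [1, 0]]], dtype=np.int64)              # shape (1, 2, 2)
logits = np.random.randn(1, 2, 2, 2).astype(np.float32)    # shape (1, 2, 2, 2)

label_batch = tf.constant(labels)
softmax_logits = tf.nn.softmax(tf.constant(logits), axis=-1)

indices_not_0 = tf.where(tf.not_equal(label_batch, 0))       # (K, 3) coordinates
predict_not_0 = tf.gather_nd(softmax_logits, indices_not_0)  # (K, 2) probabilities
avg_p_not_0 = tf.reduce_mean(predict_not_0, axis=0)          # (2,) class-1 center

with tf.Session() as sess:
    print(sess.run([tf.shape(indices_not_0),
                    tf.shape(predict_not_0),
                    tf.shape(avg_p_not_0)]))
    # Expected output: [array([2, 3]), array([2, 2]), array([2])]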

Now, the problem is that after a few steps the loss value becomes nan.

The output of the code is (I used a very small learning rate):

Epoch[0],step[1],train batch loss = 2.87282,train acc = 0.486435.
Epoch[0],step[2],train batch loss = 2.87282,train acc = 0.485756.
Epoch[0],step[3],train batch loss = 2.87281,train acc = 0.485614.
Epoch[0],step[4],train batch loss = 2.87282,train acc = 0.485649.
Epoch[0],step[5],train batch loss = 2.87282,train acc = 0.485185.
Epoch[0],step[6],train batch loss = 2.87279,train acc = 0.485292.
Epoch[0],step[7],train batch loss = 2.87281,train acc = 0.485222.
Epoch[0],step[8],train batch loss = 2.87282,train acc = 0.484989.
Epoch[0],step[9],train batch loss = 2.87282,train acc = 0.48406.
Epoch[0],step[10],train batch loss = 2.8728,train acc = 0.483306.
Epoch[0],step[11],train batch loss = 2.87281,train acc = 0.483426.
Epoch[0],step[12],train batch loss = 2.8728,train acc = 0.482954.
Epoch[0],step[13],train batch loss = 2.87281,train acc = 0.482535.
Epoch[0],step[14],train batch loss = 2.87281,train acc = 0.482225.
Epoch[0],step[15],train batch loss = 2.87279,train acc = 0.482005.
Epoch[0],step[16],train batch loss = 2.87281,train acc = 0.48182.
Epoch[0],step[17],train batch loss = 2.87282,train acc = 0.48169.
Epoch[0],step[18],train batch loss = 2.8728,train acc = 0.481279.
Epoch[0],step[19],train batch loss = 2.87281,train acc = 0.480878.
Epoch[0],step[20],train batch loss = 2.87281,train acc = 0.480607.
Epoch[0],step[21],train batch loss = 2.87278,train acc = 0.480186.
Epoch[0],step[22],train batch loss = 2.87281,train acc = 0.479925.
Epoch[0],step[23],train batch loss = 2.87282,train acc = 0.479617.
Epoch[0],step[24],train batch loss = 2.87282,train acc = 0.479378.
Epoch[0],step[25],train batch loss = 2.87281,train acc = 0.479496.
Epoch[0],step[26],train batch loss = 2.87281,train acc = 0.479354.
Epoch[0],step[27],train batch loss = 2.87282,train acc = 0.479262.
Epoch[0],step[28],train batch loss = 2.87282,train acc = 0.479308.
Epoch[0],step[29],train batch loss = 2.87282,train acc = 0.479182.
Epoch[0],step[30],train batch loss = 2.22282,train acc = 0.478985.
Epoch[0],step[31],train batch loss = nan,train acc = 0.494112.
Epoch[0],step[32],train batch loss = nan,train acc = 0.508811.
Epoch[0],step[33],train batch loss = nan,train acc = 0.523289.
Epoch[0],step[34],train batch loss = nan,train acc = 0.536233.
Epoch[0],step[35],train batch loss = nan,train acc = 0.548851.
Epoch[0],step[36],train batch loss = nan,train acc = 0.561351.
Epoch[0],step[37],train batch loss = nan,train acc = 0.573149.
Epoch[0],step[38],train batch loss = nan,train acc = 0.584382.
Epoch[0],step[39],train batch loss = nan,train acc = 0.595006.
Epoch[0],step[40],train batch loss = nan,train acc = 0.605065.
Epoch[0],step[41],train batch loss = nan,train acc = 0.614475.
Epoch[0],step[42],train batch loss = nan,train acc = 0.623371.
Epoch[0],step[43],train batch loss = nan,train acc = 0.632092.
Epoch[0],step[44],train batch loss = nan,train acc = 0.640199.
Epoch[0],step[45],train batch loss = nan,train acc = 0.647391.

I used exactly the same code before, except that the loss function was tf.nn.sparse_softmax_cross_entropy_with_logits(), and everything worked fine, so I suppose there is something wrong with my new loss function.

My guess is that some batches may only contain labels of one class (only 0 or only 1), so one of predict_not_0 and predict_0 would then have no data, but I don't know how to write code to verify whether predict_not_0 and predict_0 actually contain any data.
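
A minimal sketch of one possible check (assuming TensorFlow 1.x graph mode and reusing the tensor names from my_loss above; the helper name is only illustrative):

import tensorflow as tf

def check_both_classes_present(label_batch, predict_0, predict_not_0):
    # Hypothetical helper, meant to be called inside my_loss right after
    # the tf.gather_nd calls.
    n_0 = tf.shape(predict_0)[0]          # pixels with label == 0
    n_not_0 = tf.shape(predict_not_0)[0]  # pixels with label != 0

    # Print the counts every time the loss is evaluated (TF 1.x tf.Print).
    predict_0 = tf.Print(predict_0, [n_0, n_not_0],
                         message='pixels with label 0 / label 1: ')

    # Fail fast if either class is completely missing from the batch.
    assert_op = tf.Assert(tf.logical_and(n_0 > 0, n_not_0 > 0), [n_0, n_not_0])
    with tf.control_dependencies([assert_op]):
        predict_0 = tf.identity(predict_0)
    return predict_0, predict_not_0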

Can someone help me find out what is wrong, and how to improve the loss function to avoid nan?

2 Answers:

Answer 0 (score: 1)

This is probably due to the use of tf.sqrt, which has the bad property of an exploding gradient near 0 (the derivative of sqrt(x) is 1/(2*sqrt(x)), which blows up as x approaches 0). So as you converge, you will gradually run into more and more numerical instability.

The solution is to get rid of tf.sqrt. For example, you could minimize the squared Euclidean distance instead.
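
A minimal sketch of that variant, reusing the structure of the loss from the question (the unused beta argument is dropped, and alpha now acts as a margin on the squared distance):

import tensorflow as tf

def my_loss_squared(label_batch, logits_batch, alpha=1.3):
    softmax_logits = tf.nn.softmax(logits_batch, axis=-1)

    indices_not_0 = tf.where(tf.not_equal(label_batch, 0))
    indices_0 = tf.where(tf.equal(label_batch, 0))

    avg_p_not_0 = tf.reduce_mean(tf.gather_nd(softmax_logits, indices_not_0), axis=0)
    avg_p_0 = tf.reduce_mean(tf.gather_nd(softmax_logits, indices_0), axis=0)

    # Squared Euclidean distance: no tf.sqrt, so no exploding gradient near 0.
    squared_distance = tf.reduce_sum(tf.square(avg_p_0 - avg_p_not_0))
    return tf.maximum(alpha - squared_distance, 0.0)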

Another potential source of the error is tf.reduce_mean, which returns NaN when it is applied to an empty tensor. You need to figure out what you want the loss to be when that happens.
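
One possible guard is to replace the plain tf.reduce_mean with a conditional mean; returning zeros for an absent class is only one arbitrary choice (a sketch, assuming TF 1.x):

import tensorflow as tf

def safe_mean_over_rows(x):
    # tf.reduce_mean(x, axis=0), but return zeros instead of NaN when x has
    # no rows. Whether zeros are the right fallback depends on your loss.
    return tf.cond(tf.greater(tf.shape(x)[0], 0),
                   lambda: tf.reduce_mean(x, axis=0),
                   lambda: tf.zeros(tf.shape(x)[1:], dtype=x.dtype))

Inside my_loss this would replace the two tf.reduce_mean calls, e.g. avg_p_0 = safe_mean_over_rows(predict_0).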

Answer 1 (score: 0)

nan is caused by 0.0/0.0, log(0.0), or similar computations in many programming languages, because floating-point arithmetic often produces numbers that are too large or too small and are treated as Infinity or zero due to limited precision.

tf.nn.softmax is not safe enough during training; try other functions such as tf.log_softmax, tf.softmax_cross_entropy_with_logits, etc.
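
For example (a sketch with TF 1.x placeholders assumed, matching the shapes from the question):

import tensorflow as tf

label_batch = tf.placeholder(tf.int64, shape=[None, None, None])        # (?, H, W), values 0/1
logits_batch = tf.placeholder(tf.float32, shape=[None, None, None, 2])  # (?, H, W, 2), raw logits

# Safer alternative 1: work with log-probabilities directly, instead of
# taking the log of a softmax output that may underflow to 0.
log_probs = tf.nn.log_softmax(logits_batch, axis=-1)

# Safer alternative 2: let the fused op handle softmax + log + cross-entropy
# in one numerically stable kernel.
xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_batch,
                                                      logits=logits_batch)
loss = tf.reduce_mean(xent)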