张量流中的损失函数(带有if - else)

时间:2016-01-31 12:49:07

标签: tensorflow

我在张量流中尝试不同的损失函数。

我想要的损失函数是一种epsilon不敏感函数(这是分量):

if(|yData-yModel|<epsilon):
    loss=0
else
    loss=|yData-yModel|    

我尝试了这个解决方案:

yData=tf.placeholder("float",[None,numberOutputs]) 

yModel=model(...

epsilon=0.2
epsilonTensor=epsilon*tf.ones_like(yData)
loss=tf.maximum(tf.abs(yData-yModel)-epsilonTensor,tf.zeros_like(yData))
optimizer = tf.train.GradientDescentOptimizer(0.25)
train = optimizer.minimize(loss)

我也用过

optimizer = tf.train.MomentumOptimizer(0.001,0.9)

我在实现中没有发现任何错误。然而,它没有收敛,而loss = tf.square(yData-yModel)收敛并且loss = tf.maximum(tf.square(yData-yModel)-epsilonTensor,tf.zeros_like(yData))也收敛。

所以,我也尝试了一些更简单的loss = tf.abs(yData-yModel),它也没有收敛。我是否犯了一些错误,或者在零或其他方面存在abs的不可微分性问题? abs函数会发生什么?

1 个答案:

答案 0 :(得分:13)

如果您的损失类似于Loss(x)=abs(x-y),那么解决方案是SGD的一个不稳定的固定点 - 以任意接近解决方案的点开始您的最小化,下一步将增加损失。

具有稳定的定点是对像SGD这样的迭代过程的收敛的要求。实际上,这意味着您的优化将向局部最小值移动,但在足够接近之后,将以与学习速率成比例的步长跳过解决方案。这是一个玩具TensorFlow程序,用于说明问题

x = tf.Variable(0.)
loss_op = tf.abs(x-1.05)
opt = tf.train.GradientDescentOptimizer(0.1)
train_op = opt.minimize(loss_op)
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
xvals = []
for i in range(20):
  unused, loss, xval = sess.run([train_op, loss_op, x])
  xvals.append(xval)
pyplot.plot(xvals)

Graph of x estimate

问题的一些解决方案:

  1. 使用更强大的解算器,例如近端梯度法
  2. 使用更多SGD友好损失功能,例如Huber Loss
  3. 使用学习率计划逐渐降低学习率
  4. 这是实现(3)上述玩具问题的方法

    x = tf.Variable(0.)
    loss_op = tf.abs(x-1.05)
    
    step = tf.Variable(0)
    learning_rate = tf.train.exponential_decay(
          0.2,   # Base learning rate.
          step,  # Current index into the dataset.
          1,     # Decay step.
          0.9    # Decay rate
    )
    
    opt = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = opt.minimize(loss_op, global_step=step)
    sess = tf.InteractiveSession()
    sess.run(tf.initialize_all_variables())
    xvals = []
    for i in range(40):
      unused, loss, xval = sess.run([train_op, loss_op, x])
      xvals.append(xval)
    pyplot.plot(xvals)
    

    enter image description here