批处理如何与TensorFlow中的损失函数进行交互?

时间:2017-02-11 07:01:30

标签: machine-learning tensorflow neural-network

我正在使用自己的损失函数在TensorFlow中训练一个多目标神经网络,并且无法找到有关批处理如何与该功能交互的文档。

例如,我在下面有我的损失函数的片段,它接受张量/预测列表并确保它们的绝对值总和不超过一个:

def fitness(predictions,actual):

    absTensor = tf.abs(predictions)
    sumTensor = tf.reduce_sum(absTensor)
    oneTensor = tf.constant(1.0)

    isGTOne = tf.greater(sumTensor,oneTensor)

    def norm(): return predictions/sumTensor
    def unchanged(): return predictions

    predictions = tf.cond(isGTOne,norm,unchanged)

    etc...

但是当我通过一批估算时,我觉得这个损失函数正在将整个输入集合归一化为此时加1,而不是每个单独的集合总和为1. I.e.
[[.8,.8],[。8,.8]] - > [[0.25,0.25],[。25,25]]
而不是期望的 [[.8,.8],[。8,.8]] - > [[0.5,0.5],[5,0.5]]

任何人都可以澄清或放下我的怀疑吗?如果这是我的功能当前的工作方式,我该如何更改?

1 个答案:

答案 0 :(得分:2)

您必须为减速操作指定减速轴,否则所有轴都将减小。传统上这是张量的第一个维度。所以,第2行应该是这样的:

sumTensor = tf.reduce_sum(absTensor, 0)

进行更改后,您将遇到另一个问题。 sumTensor将不再是标量,因此作为tf.cond的条件将不再有意义(即,对于批处理的每个条目分支是什么意思?)。你真正想要的是tf.select,因为你并不真的想要为每批项目分支逻辑。像这样:

isGTOne = tf.greater(sumTensor,oneTensor)

norm = predictions/sumTensor

predictions = tf.select(isGTOne,norm,predictions)

但是,现在看看这个,我甚至不打扰条件正常化条目。由于您现在正在批量操作,我不认为您可以通过一次批量化一个条目来获得性能。特别是,由于分割并不是真正昂贵的副作用。不妨这样做:

def fitness(predictions,actual):

  absTensor = tf.abs(predictions)
  sumTensor = tf.reduce_sum(absTensor, 0)

  predictions = predictions/sumTensor

  etc...

希望有所帮助!