在张量流中避免ctc_loss的无穷大

时间:2018-09-20 17:18:00

标签: python-3.x tensorflow loss-function

这是我的以下代码。您可以更改激活值,但不能更改目标。 tensorflow ctc_loss始终返回Inf。 我需要ctc_loss返回一些浮动值而不是Inf。我知道,INF的原因(即倍增时的激活变得非常小,因此,如果为Inf,则记录为对数)。我对解决这个问题很感兴趣。 更改inputs或执行其他操作,除了修改target以获得非Inf值。

我的代码:

inputs = tf.random_uniform([1, 9, 11]) # Do Not Change the Shape of inputs

target = tf.constant([[2,2]], dtype=tf.int32)

zero = tf.constant(0, dtype=tf.float32)
where = tf.ones(tf.shape(target))
indices = tf.where(where)
values = tf.gather_nd(target, indices)

sparse = tf.SparseTensor(indices, values, target.shape)
seq = tf.multiply(tf.ones([tf.shape(target)[0]], dtype=tf.int32), 2)
loss = tf.nn.ctc_loss(sparse, inputs, seq, time_major=False, ctc_merge_repeated=True)

with tf.Session() as sess:
    print (loss.eval())

2 个答案:

答案 0 :(得分:0)

问题是您为labels使用的ctc_loss包含重复项。设置preprocess_collapse_repeated=True即可解决,如其doc所述:

“如果preprocess_collapse_repeated为True,则在损失计算之前运行预处理步骤,其中传递到损失的重复标签合并为单个标签。如果训练标签来自(例如,强制比对)并因此具有不必要的重复,则这很有用。”

答案 1 :(得分:0)

这个问题有点老了,但是由于我遇到了同样的问题,并且当前的唯一答案似乎并不能解决(尽管很简单,所以从一开始我就认为这是正确的答案)并投票赞成!但是随后在其他示例中失败了……),这是一些几乎完整的答案。

TL; DR:要使问题中的代码起作用,请将seq = tf.multiply(tf.ones([tf.shape(target)[0]], dtype=tf.int32), 2)替换为seq = tf.multiply(tf.ones([tf.shape(target)[0]], dtype=tf.int32), 9),即 logits 的长度和而不是 target

详细信息:on the manual似乎没有得到适当的记录,而且我也不清楚为什么应该这样做(可以通过多种方式轻松获得logit的形状) tf.shape ...),但这似乎可以达到预期的效果。这样做的原因使我有些逃脱。

一个简单的实验证实了sequence_length:运行后,记录实际上已被截断了

indices = [[0, 0],
       [0, 1],
       [0, 2]]

values = [0, 1, 1]

shape = [1, 3]

tf.nn.ctc_loss(
    tf.SparseTensor(indices, values, shape),
    [[[ 10., -10., -10.],
        [-10., -10.,  10.],
        [-10.,  10., -10.],
        [-10., -10.,  10.],
        [-10.,  10., -10.]]],
    [5],
    time_major=False,
)

几乎应有0的损失(将对数中的10替换为更大的数字,则1越来越接近0)。在这里,作为logits传递的参数是单热编码矢量[0、2、1、2、1]的pre-softmax版本的近似值,其中2是分隔令牌,因此正确解码为目标“ 0 1 1”。

但是,将传递给[5]的{​​{1}}替换为sequence_length仍然会造成有限的损失,但损失不会为零。原因是它读取[4],它会解码为错误的字符串“ 0 1”。实际上,切换行以使logit表示[0, 2, 1, 2][0, 1, 2, 1, 2]会导致相同的损失,即传递[0, 1, 2, 1, 1][4]。另外,如果预测[5]传递[0, 1, 2, 1, 0]会导致0损失(因为忽略了最后一个0),正确的话,如果传递[4]会导致非零损耗(因为它随后解码为“ 0” 1 0“,这是错误的。)

所有这些实验都将是完美的……如果它们解释了问题中原始示例为何给出无限的原因。取而代之的是,从实验上看,似乎在原始问题中将[5]替换为[2]可以使它起作用,但是确实可以将其替换为[9][3]等。部分在于,即使使用[4],它也不应该真正返回[2]:它应该裁剪到logit的前两个项,并使用这些项来预测某些内容(永远不要inf,因为随机logit是始终在0到1之间)。我唯一能给自己的答案是,实际上,inf必须总是严格地大于目标的长度。从这个意义上讲,另一个建议的答案有效:如果通过删除重复项对标签进行预处理,则“ 2 2”中的目标序列将变为“ 2”,因此足够短,不会引发此问题。但是,设置sequence_length并不是正确的方法,因为它会弹each预测带有重复字符的字符串(例如,preprocess_collapse_repeated=True地面真相将变为hello world,并且logit会实际预测helo world将会受到惩罚)

如果有人对此有更多了解,我将不胜感激!