这是我的以下代码。您可以更改激活值,但不能更改目标。 tensorflow ctc_loss始终返回Inf。
我需要ctc_loss返回一些浮动值而不是Inf。我知道,INF的原因(即倍增时的激活变得非常小,因此,如果为Inf,则记录为对数)。我对解决这个问题很感兴趣。
更改inputs
或执行其他操作,除了修改target
以获得非Inf值。
我的代码:
inputs = tf.random_uniform([1, 9, 11]) # Do Not Change the Shape of inputs
target = tf.constant([[2,2]], dtype=tf.int32)
zero = tf.constant(0, dtype=tf.float32)
where = tf.ones(tf.shape(target))
indices = tf.where(where)
values = tf.gather_nd(target, indices)
sparse = tf.SparseTensor(indices, values, target.shape)
seq = tf.multiply(tf.ones([tf.shape(target)[0]], dtype=tf.int32), 2)
loss = tf.nn.ctc_loss(sparse, inputs, seq, time_major=False, ctc_merge_repeated=True)
with tf.Session() as sess:
print (loss.eval())
答案 0 :(得分:0)
问题是您为labels
使用的ctc_loss
包含重复项。设置preprocess_collapse_repeated=True
即可解决,如其doc所述:
“如果preprocess_collapse_repeated为True,则在损失计算之前运行预处理步骤,其中传递到损失的重复标签合并为单个标签。如果训练标签来自(例如,强制比对)并因此具有不必要的重复,则这很有用。”
答案 1 :(得分:0)
这个问题有点老了,但是由于我遇到了同样的问题,并且当前的唯一答案似乎并不能解决(尽管很简单,所以从一开始我就认为这是正确的答案)并投票赞成!但是随后在其他示例中失败了……),这是一些几乎完整的答案。
TL; DR:要使问题中的代码起作用,请将seq = tf.multiply(tf.ones([tf.shape(target)[0]], dtype=tf.int32), 2)
替换为seq = tf.multiply(tf.ones([tf.shape(target)[0]], dtype=tf.int32), 9)
,即 logits 的长度和而不是 target 。
详细信息::on the manual似乎没有得到适当的记录,而且我也不清楚为什么应该这样做(可以通过多种方式轻松获得logit的形状) tf.shape
...),但这似乎可以达到预期的效果。这样做的原因使我有些逃脱。
一个简单的实验证实了sequence_length
:运行后,记录实际上已被截断了
indices = [[0, 0],
[0, 1],
[0, 2]]
values = [0, 1, 1]
shape = [1, 3]
tf.nn.ctc_loss(
tf.SparseTensor(indices, values, shape),
[[[ 10., -10., -10.],
[-10., -10., 10.],
[-10., 10., -10.],
[-10., -10., 10.],
[-10., 10., -10.]]],
[5],
time_major=False,
)
几乎应有0的损失(将对数中的10替换为更大的数字,则1越来越接近0)。在这里,作为logits传递的参数是单热编码矢量[0、2、1、2、1]的pre-softmax版本的近似值,其中2是分隔令牌,因此正确解码为目标“ 0 1 1”。
但是,将传递给[5]
的{{1}}替换为sequence_length
仍然会造成有限的损失,但损失不会为零。原因是它读取[4]
,它会解码为错误的字符串“ 0 1”。实际上,切换行以使logit表示[0, 2, 1, 2]
或[0, 1, 2, 1, 2]
会导致相同的损失,即传递[0, 1, 2, 1, 1]
或[4]
。另外,如果预测[5]
传递[0, 1, 2, 1, 0]
会导致0损失(因为忽略了最后一个0),正确的话,如果传递[4]
会导致非零损耗(因为它随后解码为“ 0” 1 0“,这是错误的。)
所有这些实验都将是完美的……如果它们解释了问题中原始示例为何给出无限的原因。取而代之的是,从实验上看,似乎在原始问题中将[5]
替换为[2]
可以使它起作用,但是确实可以将其替换为[9]
,[3]
等。部分在于,即使使用[4]
,它也不应该真正返回[2]
:它应该裁剪到logit的前两个项,并使用这些项来预测某些内容(永远不要inf,因为随机logit是始终在0到1之间)。我唯一能给自己的答案是,实际上,inf
必须总是严格地大于目标的长度。从这个意义上讲,另一个建议的答案有效:如果通过删除重复项对标签进行预处理,则“ 2 2”中的目标序列将变为“ 2”,因此足够短,不会引发此问题。但是,设置sequence_length
并不是正确的方法,因为它会弹each预测带有重复字符的字符串(例如,preprocess_collapse_repeated=True
地面真相将变为hello world
,并且logit会实际预测helo world
将会受到惩罚)
如果有人对此有更多了解,我将不胜感激!