Tensorflow CTC损失:ctc_merge_repeated参数

时间:2017-08-08 12:22:44

标签: python c++ tensorflow ocr

我正在使用Tensorflow 1.0及其CTC丢失[1]。 在训练时,我有时会得到“找不到有效路径”。警告(这会损害学习)。由于其他Tensorflow用户有时会报告较高的学习率,因此

在分析了一下之后,我发现了导致此警告的模式:

  • 将输入序列输入ctc_loss,长度为 seqLen
  • 使用 labelLen 字符
  • 提供标签
  • 标签中有 numRepeatedChars 重复的字符,其中我将“ab”视为0,将“aa”视为1,将“aaa”视为2,依此类推
  • 发生
  • 警告时: seqLen - labelLen < numRepeatedChars

三个例子:

  • Ex.1:label =“abb”,len(label)= 3,len(inputSequence)= 3 => (3-3 = 0)< 1为真 - >警告
  • 例2:label =“abb”,len(label)= 3,len(inputSequence)= 4 => (4-3 = 1)< 1为假 - >没有警告
  • 例3:label =“bbb”,len(标签)= 3,len(inputSequence)= 4 => (4-3 = 1)< 2为真 - >警告

当我现在设置ctc_loss参数ctc_merge_repeated = False时,警告消失。

三个问题:

  • Q1:为什么重复的字符出现时会出现警告?我想,只要输入序列不短于目标标签,就没有问题。当重复的字符在标签中合并时,它会变得更短,因此输入序列不短的条件仍然存在。
  • Q2:为什么ctc_loss在默认设置下会产生此警告?重复的字符在域中是常见的,使用CTC,例如手写文本识别(HTR)
  • 问题3:做HTR时我应该使用哪些设置?当然标签可以有重复的字符。因此ctc_merge_repeated = False是有道理的。有什么建议?

重现警告的Python程序:

import tensorflow as tf
import numpy as np

def createGraph():
    tinputs=tf.placeholder(tf.float32, [100, 1, 65]) # max 100 time steps, 1 batch element, 64+1 classes
    tlabels=tf.SparseTensor(tf.placeholder(tf.int64, shape=[None,2]) , tf.placeholder(tf.int32,[None]), tf.placeholder(tf.int64,[2])) # labels
    tseqLen=tf.placeholder(tf.int32, [None]) # list of sequence length in batch
    tloss=tf.reduce_mean(tf.nn.ctc_loss(labels=tlabels, inputs=tinputs, sequence_length=tseqLen, ctc_merge_repeated=True)) # ctc loss
    return (tinputs, tlabels, tseqLen, tloss)

def getNextBatch(nc): # next batch with given number of chars in label
    indices=[[0,i] for i in range(nc)]
    values=[i%65 for i in range(nc)]
    values[0]=0
    values[1]=0 # TODO: (un)comment this to trigger warning
    shape=[1, nc]
    labels=tf.SparseTensorValue(indices, values, shape)
    seqLen=[nc]
    inputs=np.random.rand(100, 1, 65)
    return (labels, inputs, seqLen) 


(tinputs, tlabels, tseqLen, tloss)=createGraph()

sess=tf.Session()
sess.run(tf.global_variables_initializer())

nc=3 # number of chars in label
print('next batch with 1 element has label len='+str(nc))
(labels, inputs, seqLen)=getNextBatch(nc)
res=sess.run([tloss], { tlabels: labels, tinputs:inputs, tseqLen:seqLen } )

这是C ++ Tensorflow代码[2],警告来自:

// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero) {
    LOG(WARNING) << "No valid path found.";
    dy_b = y;
    return;
}

[1] https://www.tensorflow.org/versions/r1.0/api_docs/python/tf/nn/ctc_loss

[2] https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/ctc/ctc_loss_calculator.cc

1 个答案:

答案 0 :(得分:1)

好吧,明白了,这不是一个错误,这就是CTC的工作方式:让我们举一个警告发生的例子:输入序列的长度是2,标签是“aa”(也是长度2)。

现在产生“aa”的最短路径是a-&gt; blank-&gt; a(长度3)。 但是对于标记“ab”,最短路径是a-> b(长度2)。 这说明为什么对于像“aa”这样的重复标签,输入序列必须更长。它只是通过插入空白来重复标签在CTC中编码的方式。

标签重复因此在修复输入大小时会减少允许标签的最大长度。