如何修改网络的连接主义时间分类(CTC)层也给我们一个置信度分数?

时间:2018-06-01 07:15:34

标签: tensorflow deep-learning ocr text-recognition

我试图通过训练CRNN(CNN + LSTM + CTC)模型来识别来自文字裁剪图像的单词。我很困惑如何添加置信度分数和认可的单词。我正在跟随https://github.com/TJCVRS/CRNN_Tensorflow的实施。有人可以建议我如何修改网络的连接主义时间分类(CTC)层也给我们一个置信度分数?

2 个答案:

答案 0 :(得分:0)

我现在可以想到两种解决方案:

  1. TensorFlow解码器都提供有关已识别文本分数的信息。 ctc_greedy_decoder返回neg_sum_logits,其中包含每个批处理元素的分数。 ctc_beam_search_decoder也是如此,它返回log_probabilities,其中包含每个批处理元素的每个梁的分数。
  2. 从两个解码器中的任何一个获取已识别的文本。将另一个CTC损失函数放入代码中,并将RNN输出矩阵和识别的文本输入到损失函数中。结果将是在矩阵中查看给定文本的概率(好的,你必须撤消减号和日志,但这应该很容易)。
  3. 解决方案(1)实现起来更快,更简单,但是,解决方案(2)更准确。但只要CRNN训练有素且光束搜索解码器的波束宽度足够大,差异就不会太大。

    查看以下行中的TF-CRNN代码 - 该分数已作为变量log_prob返回:https://github.com/MaybeShewill-CV/CRNN_Tensorflow/blob/master/tools/train_shadownet.py#L62

    这是一个自包含的代码示例,它说明了解决方案(2): https://gist.github.com/githubharald/8b6f3d489fc014b0faccbae8542060dc

答案 1 :(得分:0)

我自己的一个更新:

最终,我通过将预测标签传递回ctc损失函数,并对所得损失的负数取反对数,从而获得了一个分数。我发现此值比使用log_prob的反日志非常准确。