K.ctc_batch_cost()

时间:2019-03-14 11:07:12

标签: python tensorflow keras

我已经使用Keras下载了ocr的代码,该代码应用了CRNN网络,并将CTC损失用作损失函数。 但是,我真的是CTC丢失的新手,只是在使用K.ctc_batch_cost()时遇到了麻烦,尤其是input_length的含义。在keras文件中,

  

tf.keras.backend.ctc_batch_cost(       y_true,       y_pred,       input_length,       label_length   )

  1. y_true:包含真值标签的张量(样本,max_string_length)。
  2. y_pred:包含预测或softmax输出的张量(样本,time_steps,num_categories)。
  3. input_length:张量(样本,1个),包含y_pred中每个批处理项目的序列长度。
  4. label_length:张量(样本,1个),包含y_true中每个批处理项目的序列长度。

    但是,我的问题是input_length的含义是什么?是LSTM输出的维度吗?

1 个答案:

答案 0 :(得分:0)

一个示例的CTC损失是在2D阵列(T,C)上计算的。 C必须等于字符数+ 1(空白字符)。 C包含某个时间戳记中字符的概率分布。 T将是时间戳数。

T的长度应为2 * max_string_length。长度为T的y_true的所有可能编码将用于负对数损失计算。

通常是上一层输出的形状。