我已经使用Keras下载了ocr的代码,该代码应用了CRNN网络,并将CTC损失用作损失函数。
但是,我真的是CTC丢失的新手,只是在使用K.ctc_batch_cost()
时遇到了麻烦,尤其是input_length的含义。在keras文件中,
tf.keras.backend.ctc_batch_cost( y_true, y_pred, input_length, label_length )
label_length:张量(样本,1个),包含y_true中每个批处理项目的序列长度。
但是,我的问题是input_length的含义是什么?是LSTM输出的维度吗?
答案 0 :(得分:0)
T的长度应为2 * max_string_length。长度为T的y_true的所有可能编码将用于负对数损失计算。
通常是上一层输出的形状。