pytorch中nn.CTCLoss的序列标签用于Cudnn

时间:2019-04-18 18:02:49

标签: python pytorch

如何在pytorch 1.0中特别格式化标签以使用CTCLoss

原始论文中描述的方式是有一个额外的空白类+令牌类 在按字符分类时, 因此,如果我的原始标签是“ CAAT” 索引为C:1,A:2,T:3

对于正常的交叉熵损失,标记应为[1,2,2,3]

CTC损失应该是什么,如果我只是在原始标签中的每个令牌后面盲目添加空白令牌,标签会做什么:

  

要按字符预测->“ CAAT”

     

令牌映射->空白:0 C:1,A:2,T:3

     

标签-> [1,0,2,0,2,0,3,0]

这是正确的吗?

使用CUDNN pytorch时也提到需要以“串联形式”

如何?连贯在哪个维度?在这种情况下,目标长度是多少?

0 个答案:

没有答案