如何在正确的维度中使用mxnet warpctc

时间:2018-03-12 14:23:17

标签: mxnet

FC = mx.sym.FullyConnected(data=x_3,flatten=False, num_hidden=n_class)
x = mx.sym.softmax(data=FC)

sm_label = mx.sym.Reshape(data=label, shape=(-1,))
sm_label = mx.sym.Cast(data = sm_label, dtype = ‘int32’)
sm = mx.sym.WarpCTC(data=x, label=sm_label, label_length =n_len ,
input_length =rnn_length )

我的x层的形状[(32L,35L,27L)](bacthsize,input_lenth,n_class)
label的形状[(32L,21L)](batchsize,label_lenth)
warpctc
simple_bind错误。
参数:
数据:(32,1L,32L,286L)
标签:(32,21L)
运算符warpctc48出错:形状不一致,提供= [672],推断形状= [0,1]

我该怎么办?

1 个答案:

答案 0 :(得分:1)

MXNet repo有一个WarpCTC示例here。您可以使用python lstm_ocr_train.py --gpu 1 --num_proc 4 --loss warpctc font/Ubuntu-M.ttf运行培训。在该示例中,以下是与WarpCTC运算符一起使用的预测和标签的形状:

Prediction is (10240, 11)
Label is (512,)

label_length: 4
input_length: 80

batch_size = 128
seq_length = 80

在上述情况下,

  • 预测是(batch_size * seq_length,n_class)。
  • 标签是(batch_size * label_length,)。

按照示例的说明,我建议在预测形状中调用WarpCTC =(1120,27),标签形状=(672,),label_length = 21,输入长度= 35。