Determining classes, labels, input lengths and label lengths in a Keras CTC implementation

Time: 2018-10-16 16:23:12

Tags: python tensorflow keras

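I am trying to develop a speech-to-text engine and want to incorporate CTC loss into it. I am following this example: https://github.com/keras-team/keras/blob/master/examples/image_ocr.py. My input data has shape (batch_size, 40, 11) and the output has shape (batch_size, 40, 138). I have (137 + 1) classes, where one label is reserved for the CTC blank. The input lengths and label lengths each have shape (batch_size, 1). The labels have shape (batch_size, 40). My code is shared below.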

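import numpy as np
from keras import backend as K
from keras.layers import Input, Dense, LSTM, Bidirectional, Lambda
from keras.models import Model

def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    # y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)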

size = x.shape[0]
# max_len = max_length_labels(labels)
input_length = np.zeros([size, 1])
label_length = np.zeros([size, 1])
audio_labels = np.ones([size, 40])

inputs = {'the_input': x_xt,
          'the_labels': audio_labels,
          'input_length': input_length,
          'label_length': label_length,
          }


outputs = {'ctc': np.zeros([size])}
input_length2 = np.zeros([size, 1])
label_length2 = np.zeros([size, 1])

for i in range(size):
    input_length2[i] = 40
    label_length2[i] = 40
input_shape = (40, 11)

input_data = Input(name='the_input', shape=input_shape, dtype='float32')
inner = Bidirectional(LSTM(10, return_sequences=True))(input_data)

y_pred = Dense(138, activation='softmax')(inner)

Model(inputs=input_data, outputs=y_pred).summary()

labels = Input(name='the_labels',
               shape=[40], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')

loss_out = Lambda(
    ctc_lambda_func, output_shape=(1,),
    name='ctc')([y_pred, labels, input_length, label_length])

model = Model(inputs=[input_data, labels, input_length, label_length],
              outputs=loss_out)

model.summary()
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
model.fit({'the_input': x, 'the_labels': y_utf,
           'input_length': input_length2, 'label_length': label_length2},
          {'ctc': y}, epochs=1, batch_size=100)

When I run this code, I get the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Saw a non-null label (index >= num_classes - 1) following a null label, batch: 0 num_classes: 138 labels: 
 [[Node: ctc/CTCLoss = CTCLoss[_class=["loc:@training/Adadelta/gradients/ctc/CTCLoss_grad/mul"], ctc_merge_repeated=true, ignore_longer_outputs_than_inputs=false, preprocess_collapse_repeated=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](ctc/Log, ctc/ToInt64, ctc/ToInt32_2, ctc/ToInt32_1)]]

What could be going wrong?
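
For readers who hit the same message, one way to narrow it down is to inspect the label arrays before calling `model.fit`. The error text itself says that any label index >= num_classes - 1 is treated as a null (blank) label, so with 138 classes the non-blank labels would need to lie in 0..136 within the first `label_length` entries of each row. The helper below is a minimal NumPy sketch under that assumption; `check_ctc_batch` and the toy arrays are hypothetical names used only for illustration and are not part of Keras or of the code in the question.

```python
import numpy as np


def check_ctc_batch(labels, input_length, label_length, num_classes):
    """Return a list of problems found in a CTC batch (empty list if none).

    Assumes the blank label is the last class index (num_classes - 1),
    as the error message in the question implies.
    """
    labels = np.asarray(labels)
    input_length = np.asarray(input_length).reshape(-1).astype(int)
    label_length = np.asarray(label_length).reshape(-1).astype(int)
    blank = num_classes - 1
    problems = []

    for b in range(labels.shape[0]):
        n = label_length[b]
        if n > labels.shape[1]:
            problems.append(
                f"sample {b}: label_length {n} exceeds label width {labels.shape[1]}")
            n = labels.shape[1]
        if n > input_length[b]:
            problems.append(
                f"sample {b}: label_length {n} exceeds input_length {input_length[b]}")
        # Only the first n entries of the row are real labels; the rest is padding.
        seq = labels[b, :n]
        bad = seq[(seq < 0) | (seq >= blank)]
        if bad.size:
            problems.append(
                f"sample {b}: label values outside [0, {blank - 1}]: {bad.tolist()}")
    return problems


# Toy batch (hypothetical data): the second row contains 137, the blank index.
toy_labels = np.array([[3, 5, 136, 0],
                       [2, 137, 4, 0]])
toy_input_length = np.full((2, 1), 40)
toy_label_length = np.array([[3], [3]])

for problem in check_ctc_batch(toy_labels, toy_input_length, toy_label_length, 138):
    print(problem)
```

If the values in `y_utf` were raw character codes rather than indices into a 137-symbol vocabulary, a check like this would flag them; whether that is the cause here cannot be determined from the post alone. Separately, setting every label length to 40, equal to the input length, leaves no spare time steps for the blanks that CTC needs between repeated adjacent labels, so the true per-sample label lengths are usually passed instead.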

0 answers:

No answers yet