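I am trying to build a speech-to-text engine and want to incorporate CTC loss into it. I am following this example: https://github.com/keras-team/keras/blob/master/examples/image_ocr.py. My input data has shape (batch_size, 40, 11) and the output has shape (batch_size, 40, 138). I have (137 + 1) classes, where one label is reserved for the CTC blank. The input lengths and label lengths each have shape (batch_size, 1), and the labels have shape (batch_size, 40). My code is shared below.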
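import numpy as np
from keras import backend as K
from keras.layers import Input, Bidirectional, LSTM, Dense, Lambda
from keras.models import Model

def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    # y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
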
size = x.shape[0]
# max_len = max_length_labels(labels)
input_length = np.zeros([size, 1])
label_length = np.zeros([size, 1])
audio_labels = np.ones([size, 40])
inputs = {'the_input': x_xt,
          'the_labels': audio_labels,
          'input_length': input_length,
          'label_length': label_length,
          }
outputs = {'ctc': np.zeros([size])}
input_length2 = np.zeros([size, 1])
label_length2 = np.zeros([size, 1])
for i in range(size):
    input_length2[i] = 40
    label_length2[i] = 40

input_shape = (40, 11)
input_data = Input(name='the_input', shape=input_shape, dtype='float32')
inner = Bidirectional(LSTM(10, return_sequences=True))(input_data)
y_pred = Dense(138, activation='softmax')(inner)
Model(inputs=input_data, outputs=y_pred).summary()
labels = Input(name='the_labels',
               shape=[40], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(
    ctc_lambda_func, output_shape=(1,),
    name='ctc')([y_pred, labels, input_length, label_length])
model = Model(inputs=[input_data, labels, input_length, label_length],
              outputs=loss_out)
model.summary()
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
model.fit({'the_input': x, 'the_labels': y_utf,
           'input_length': input_length2, 'label_length': label_length2},
          {'ctc': y}, epochs=1, batch_size=100)

When I run this code, I get the following error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Saw a non-null label (index >= num_classes - 1) following a null label, batch: 0 num_classes: 138 labels:
[[Node: ctc/CTCLoss = CTCLoss[_class=["loc:@training/Adadelta/gradients/ctc/CTCLoss_grad/mul"], ctc_merge_repeated=true, ignore_longer_outputs_than_inputs=false, preprocess_collapse_repeated=false, _device="/job:localhost/replica:0/task:0/device:CPU:0"](ctc/Log, ctc/ToInt64, ctc/ToInt32_2, ctc/ToInt32_1)]]
What could be going wrong?
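
The error message mentions "index >= num_classes - 1", which suggests the loss treats index 137 (the last of the 138 classes) as the blank, so real labels would need to lie in 0..136. Below is a minimal sanity check on the label array under that assumption. The helper name find_bad_labels and the toy data are hypothetical and only for illustration; this is a diagnostic sketch, not a confirmed fix.

```python
def find_bad_labels(labels, label_lengths, num_classes):
    """Return (row, position, value) for each label outside [0, num_classes - 2].

    Assumption: the last index (num_classes - 1) is reserved for the CTC
    blank, as the error message suggests. Only the first label_lengths[i]
    entries of row i are inspected.
    """
    blank = num_classes - 1
    bad = []
    for row, (seq, length) in enumerate(zip(labels, label_lengths)):
        for pos in range(int(length)):
            value = int(seq[pos])
            if value < 0 or value >= blank:
                bad.append((row, pos, value))
    return bad


# Toy data: with 138 classes, labels 0..136 are valid and 137 is the blank.
toy_labels = [[5, 136, 0, 7], [3, 137, 9, 200]]
toy_lengths = [4, 4]
print(find_bad_labels(toy_labels, toy_lengths, num_classes=138))
```

With the arrays from the code above, the call would presumably be find_bad_labels(y_utf, label_length2.ravel(), 138). A non-empty result would mean that y_utf contains values the loss cannot distinguish from the blank, for example raw character codes that were never mapped into the 0..136 range.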