Input batch size problem with Keras CTC for sequence classification

Asked: 2020-09-16 07:36:54

Tags: python tensorflow keras ctc

I want to perform an OCR task on input images whose target sequences may come in up to 10 different (variable) lengths. This is my training data:

X : array of float (25000, 32, 200, 1)
X_lengths : array of int (25000,), containing identical values
Y : array of int (25000, 9), encoded numerical representations of digits padded to max sequence length
Y_lengths : array of int (25000, 9), with the respective length of output sequences
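
For reference, here is a minimal sketch of dummy arrays with the same shapes (the real data is loaded elsewhere; the values, and the assumed per-sample length of 32, are only placeholders):

import numpy as np

N = 25000  # number of training samples

X = np.random.rand(N, 32, 200, 1).astype('float32')   # input images, 32 x 200, one channel
X_lengths = np.full((N,), 32, dtype='int64')           # identical value for every sample (32 assumed here)
Y = np.random.randint(0, 10, size=(N, 9))              # digit labels padded to max sequence length 9
Y_lengths = np.random.randint(1, 10, size=(N, 9))      # lengths of the output sequences (shape as stated above)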

Now I define my model as follows:

from tensorflow.keras.layers import (Input, Conv2D, MaxPool2D, BatchNormalization,
                                     Lambda, Bidirectional, LSTM, Dense)
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
import numpy as np

inputs = Input(shape=(32, 200, 1)) # Size of respective input images

conv_1 = Conv2D(64, (3,3), activation = 'relu', padding='same')(inputs)
pool_1 = MaxPool2D(pool_size=(2, 2), strides=2)(conv_1)
conv_2 = Conv2D(128, (3,3), activation = 'relu', padding='same')(pool_1)
pool_2 = MaxPool2D(pool_size=(2, 2), strides=2)(conv_2)
conv_3 = Conv2D(256, (3,3), activation = 'relu', padding='same')(pool_2)
conv_4 = Conv2D(256, (3,3), activation = 'relu', padding='same')(conv_3)
pool_4 = MaxPool2D(pool_size=(2, 1))(conv_4)
conv_5 = Conv2D(512, (3,3), activation = 'relu', padding='same')(pool_4)
batch_norm_5 = BatchNormalization()(conv_5)    
conv_6 = Conv2D(512, (3,3), activation = 'relu', padding='same')(batch_norm_5)
batch_norm_6 = BatchNormalization()(conv_6)
pool_6 = MaxPool2D(pool_size=(2, 1))(batch_norm_6)    
conv_7 = Conv2D(512, (2,2), activation = 'relu')(pool_6)

squeezed = Lambda(lambda x: K.squeeze(x, 1))(conv_7)

blstm_1 = Bidirectional(LSTM(128, return_sequences=True, dropout = 0.2))(squeezed)
blstm_2 = Bidirectional(LSTM(128, return_sequences=True, dropout = 0.2))(blstm_1)

outputs = Dense(11, activation = 'softmax')(blstm_2)

act_model = Model(inputs, outputs)

labels = Input(name='the_labels', shape=[9], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')

def ctc_lambda_func(args):
    y_pred, labels, input_length, label_length = args
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([outputs, labels, input_length, label_length])
model = Model(inputs=[inputs, labels, input_length, label_length], outputs=loss_out)

model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer = 'adam')
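
For context, K.ctc_batch_cost(y_true, y_pred, input_length, label_length) expects input_length to count the time steps of y_pred (shape (samples, time_steps, num_classes)) and label_length to count the true label lengths, both passed as (samples, 1) columns. Below is a minimal standalone sketch with dummy values that match the dimensions above (49 time steps, 11 classes, labels up to length 9); it only illustrates the expected argument shapes, not my training code:

import numpy as np
import tensorflow.keras.backend as K

batch, time_steps, num_classes, max_label_len = 4, 49, 11, 9

y_pred = np.random.rand(batch, time_steps, num_classes).astype('float32')  # softmax outputs per time step
y_true = np.random.randint(0, 10, size=(batch, max_label_len)).astype('float32')
input_length = np.full((batch, 1), time_steps, dtype='int64')              # lengths measured in y_pred time steps (<= 49)
label_length = np.random.randint(1, max_label_len + 1, size=(batch, 1))

loss = K.ctc_batch_cost(y_true, y_pred, input_length, label_length)
print(loss.shape)  # (4, 1): one CTC loss value per sample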

However, when I try to fit the model like this:

model.fit(x=[X, Y, X_lengths, Y_lengths], y=np.zeros(25000), batch_size=8, epochs = 100, verbose = 1)

I get the following error message:

InvalidArgumentError:  sequence_length(0) <= 49
     [[node functional_15/ctc/CTCLoss (defined at <ipython-input-750-49adc0833159>:44) ]] [Op:__inference_train_function_140562]

Function call stack:
train_function

I cannot figure out what the problem is. As far as I understand, the error message means that the second dimension of the training batch (the time steps) must not be greater than 49, but it should in fact be smaller (32), so I cannot see where the problem is.
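
For completeness, this is roughly how the output time dimension compares with the values I pass as input_length (a minimal check, assuming the model and arrays defined above are in scope):

# The prediction model's output has 49 time steps (see the summary below)
print(act_model.output_shape)            # (None, 49, 11)
# These values are fed into the 'input_length' input by model.fit
print(X_lengths.min(), X_lengths.max())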

Model summary:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_14 (InputLayer)           [(None, 32, 200, 1)] 0                                            
__________________________________________________________________________________________________
conv2d_72 (Conv2D)              (None, 32, 200, 64)  640         input_14[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_35 (MaxPooling2D) (None, 16, 100, 64)  0           conv2d_72[0][0]                  
__________________________________________________________________________________________________
conv2d_73 (Conv2D)              (None, 16, 100, 128) 73856       max_pooling2d_35[0][0]           
__________________________________________________________________________________________________
max_pooling2d_36 (MaxPooling2D) (None, 8, 50, 128)   0           conv2d_73[0][0]                  
__________________________________________________________________________________________________
conv2d_74 (Conv2D)              (None, 8, 50, 256)   295168      max_pooling2d_36[0][0]           
__________________________________________________________________________________________________
conv2d_75 (Conv2D)              (None, 8, 50, 256)   590080      conv2d_74[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_37 (MaxPooling2D) (None, 4, 50, 256)   0           conv2d_75[0][0]                  
__________________________________________________________________________________________________
conv2d_76 (Conv2D)              (None, 4, 50, 512)   1180160     max_pooling2d_37[0][0]           
__________________________________________________________________________________________________
batch_normalization_35 (BatchNo (None, 4, 50, 512)   2048        conv2d_76[0][0]                  
__________________________________________________________________________________________________
conv2d_77 (Conv2D)              (None, 4, 50, 512)   2359808     batch_normalization_35[0][0]     
__________________________________________________________________________________________________
batch_normalization_36 (BatchNo (None, 4, 50, 512)   2048        conv2d_77[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_38 (MaxPooling2D) (None, 2, 50, 512)   0           batch_normalization_36[0][0]     
__________________________________________________________________________________________________
conv2d_78 (Conv2D)              (None, 1, 49, 512)   1049088     max_pooling2d_38[0][0]           
__________________________________________________________________________________________________
lambda_9 (Lambda)               (None, 49, 512)      0           conv2d_78[0][0]                  
__________________________________________________________________________________________________
bidirectional_6 (Bidirectional) (None, 49, 256)      656384      lambda_9[0][0]                   
__________________________________________________________________________________________________
bidirectional_7 (Bidirectional) (None, 49, 256)      394240      bidirectional_6[0][0]            
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 49, 11)       2827        bidirectional_7[0][0]            
__________________________________________________________________________________________________
the_labels (InputLayer)         [(None, 9)]          0                                            
__________________________________________________________________________________________________
input_length (InputLayer)       [(None, 1)]          0                                            
__________________________________________________________________________________________________
label_length (InputLayer)       [(None, 1)]          0                                            
__________________________________________________________________________________________________
ctc (Lambda)                    (None, 1)            0           dense_10[0][0]                   
                                                                 the_labels[0][0]                 
                                                                 input_length[0][0]               
                                                                 label_length[0][0]               
==================================================================================================
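
For reference, the 49 time steps in the summary come from the width dimension of the input: the first two (2, 2) poolings halve 200 to 100 and then 50, the later (2, 1) poolings only reduce the height, and the final unpadded 2x2 convolution shrinks the width from 50 to 49. A minimal sketch of that arithmetic:

width = 200
width //= 2    # max_pooling2d_35, pool (2, 2): 200 -> 100
width //= 2    # max_pooling2d_36, pool (2, 2): 100 -> 50
               # the (2, 1) poolings halve only the height; width stays 50
width -= 1     # conv2d_78, 2x2 kernel, 'valid' padding: 50 - 2 + 1 = 49
print(width)   # 49 -> the time-step dimension after the squeeze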

0 Answers:

No answers yet.