Keras ctc_decode: Shape must be rank 1 but is rank 2

Date: 2018-06-01 11:43:30

Tags: python tensorflow keras

I am implementing OCR with Keras, using the TensorFlow backend.

I want to use keras.backend.ctc_decode in my implementation.

I have a model class:

import keras


def ctc_lambda_func(args):
    y_pred, y_true, input_x_width, input_y_width = args
    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    # y_pred = y_pred[:, 2:, :]
    return keras.backend.ctc_batch_cost(y_true, y_pred, input_x_width, input_y_width)


class ModelOcropy(keras.Model):
    def __init__(self, alphabet: str):
        self.img_height = 48
        self.lstm_size = 100
        self.alphabet_size = len(alphabet)

        # check backend input shape (channel first/last)
        if keras.backend.image_data_format() == "channels_first":
            input_shape = (1, None, self.img_height)
        else:
            input_shape = (None, self.img_height, 1)

        # data input
        input_x = keras.layers.Input(input_shape, name='x')

        # training inputs
        input_y = keras.layers.Input((None,), name='y')
        input_x_widths = keras.layers.Input([1], name='x_widths')
        input_y_widths = keras.layers.Input([1], name='y_widths')

        # network
        flattened_input_x = keras.layers.Reshape((-1, self.img_height))(input_x)
        bidirectional_lstm = keras.layers.Bidirectional(
            keras.layers.LSTM(self.lstm_size, return_sequences=True, name='lstm'),
            name='bidirectional_lstm'
        )(flattened_input_x)
        dense = keras.layers.Dense(self.alphabet_size, activation='relu')(bidirectional_lstm)
        y_pred = keras.layers.Softmax(name='y_pred')(dense)

        # ctc loss
        ctc = keras.layers.Lambda(ctc_lambda_func, output_shape=[1], name='ctc')(
            [dense, input_y, input_x_widths, input_y_widths]
        )

        # init keras model
        super().__init__(inputs=[input_x, input_x_widths, input_y, input_y_widths], outputs=[y_pred, ctc])

        # ctc decoder
        top_k_decoded, _ = keras.backend.ctc_decode(y_pred, input_x_widths)
        self.decoder = keras.backend.function([input_x, input_x_widths], [top_k_decoded[0]])
        # decoded_sequences = self.decoder([test_input_data, test_input_lengths])

My usage of ctc_decode comes from another post: Keras using Lambda layers error with K.ctc_decode

I get the error:

ValueError: Shape must be rank 1 but is rank 2 for 'CTCGreedyDecoder' (op: 'CTCGreedyDecoder') with input shapes: [?,?,7], [?,1].

I think I have to squeeze input_x_widths, but Keras doesn't seem to have such a function (it always outputs something like (batch_size, 1)).

2 answers:

Answer 0 (score: 1)

Indeed, that function expects a 1D tensor, and you have a 2D tensor.

  • Keras does have keras.backend.squeeze(x, axis=-1).
  • You can also use keras.backend.reshape(x, (-1,)).

If you need to get back to the old shape afterwards, you can use (see the short sketch after this list):

  • keras.backend.expand_dims(x)
  • keras.backend.reshape(x, (-1, 1))
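
For illustration, a minimal sketch of the rank round trip; the `widths` placeholder here is a stand-in, not part of the model above:

    import keras.backend as K

    widths = K.placeholder(shape=(None, 1))   # rank 2: (batch_size, 1)
    widths_1d = K.squeeze(widths, axis=-1)    # rank 1: (batch_size,) -- what ctc_decode expects
    widths_2d = K.expand_dims(widths_1d)      # back to rank 2: (batch_size, 1)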

Answer 1 (score: 1)

The complete fix:

    # ctc decoder
    # flatten the widths to rank 1, as required by ctc_decode
    flattened_input_x_width = keras.backend.reshape(input_x_widths, (-1,))
    top_k_decoded, _ = keras.backend.ctc_decode(y_pred, flattened_input_x_width)
    self.decoder = keras.backend.function([input_x, flattened_input_x_width], [top_k_decoded[0]])
    # decoded_sequences = self.decoder([input_x, flattened_input_x_width])
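
A hedged usage sketch (the names test_images and test_widths are made up for illustration, and the shapes assume the channels-last input above): since the decoder now takes the flattened widths tensor as its second input, you feed it a rank-1 array of sequence lengths:

    import numpy as np

    model = ModelOcropy(alphabet="abcdef")
    # batch of 2 images: 100 time steps (image width), height 48, 1 channel
    test_images = np.zeros((2, 100, 48, 1), dtype="float32")
    test_widths = np.array([100, 100], dtype="int32")   # rank 1: (batch_size,)
    decoded_sequences = model.decoder([test_images, test_widths])[0]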