手写文字识别(CNN + LSTM + CTC)需要RNN解释

时间:2019-03-07 14:08:09

标签: python tensorflow deep-learning handwriting

我试图理解以下代码,它们是python&tensorflow中的代码。我正在尝试实现手写文本识别。我指的是以下代码here

我不明白为什么RNN输出会通过“ atrous_conv2d”输入

这是我模型的架构,它接受CNN输入并传递到RNN流程,然后将其传递给CTC。

 def build_RNN(self, rnnIn4d):

    rnnIn3d = tf.squeeze(rnnIn4d, axis=[2])  # squeeze remove 1 dimensions, here it removes the 2nd index

    n_hidden = 256
    n_layers = 2
    cells = []

    for _ in range(n_layers):
        cells.append(tf.nn.rnn_cell.LSTMCell(num_units=n_hidden))

    stacked = tf.nn.rnn_cell.MultiRNNCell(cells)  # combine the 2 LSTMCell created

    # BxTxF -> BxTx2H
    ((fw, bw), _) = tf.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d,
                                                    dtype=rnnIn3d.dtype)

    # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
    concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)

    # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
    kernel = tf.Variable(tf.truncated_normal([1, 1, n_hidden * 2, len(self.char_list) + 1], stddev=0.1))
    rnn = tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME')

    return tf.squeeze(rnn, axis=[2])

1 个答案:

答案 0 :(得分:1)

CTC损失层的输入形式为B x T x C

B-批次大小 T-输出的最大长度(由于空白字符,最大字长是原来的两倍) C-字符数+ 1(空白字符)

输入到圆环的形状为(B x T x 1 X 2T)==(批量,高度,宽度,通道) 我们使用的过滤器是(1,1,2T,C)==(高度,宽度,输入通道,输出通道)

无奈的CNN之后,我们将获得(B,T,1,C),这是CTC所需的输出

注意:由于tf是行专业的,因此在将图像输入到CNN之前我们将进行转置。

速率为1的空洞与普通转化层相同。