我试图理解以下代码,它们是python&tensorflow中的代码。我正在尝试实现手写文本识别。我指的是以下代码here
我不明白为什么RNN输出会通过“ atrous_conv2d”输入
这是我模型的架构,它接受CNN输入并传递到RNN流程,然后将其传递给CTC。
def build_RNN(self, rnnIn4d):
rnnIn3d = tf.squeeze(rnnIn4d, axis=[2]) # squeeze remove 1 dimensions, here it removes the 2nd index
n_hidden = 256
n_layers = 2
cells = []
for _ in range(n_layers):
cells.append(tf.nn.rnn_cell.LSTMCell(num_units=n_hidden))
stacked = tf.nn.rnn_cell.MultiRNNCell(cells) # combine the 2 LSTMCell created
# BxTxF -> BxTx2H
((fw, bw), _) = tf.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnnIn3d,
dtype=rnnIn3d.dtype)
# BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)
# project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
kernel = tf.Variable(tf.truncated_normal([1, 1, n_hidden * 2, len(self.char_list) + 1], stddev=0.1))
rnn = tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME')
return tf.squeeze(rnn, axis=[2])
答案 0 :(得分:1)
CTC损失层的输入形式为B x T x C
B-批次大小 T-输出的最大长度(由于空白字符,最大字长是原来的两倍) C-字符数+ 1(空白字符)
输入到圆环的形状为(B x T x 1 X 2T)==(批量,高度,宽度,通道) 我们使用的过滤器是(1,1,2T,C)==(高度,宽度,输入通道,输出通道)
无奈的CNN之后,我们将获得(B,T,1,C),这是CTC所需的输出
注意:由于tf是行专业的,因此在将图像输入到CNN之前我们将进行转置。
速率为1的空洞与普通转化层相同。