我可以训练字符级LSTM来使用预先准备的标签标记嘈杂的文本吗?

时间:2019-04-25 07:44:35

标签: python tensorflow keras lstm

我对ML还是很陌生,很难在TensorFlow和Keras上脱颖而出。我想构建一个基于字符的LSTM并训练一种语言模型,该模型可用于标记使用OCR软件捕获的嘈杂文本。 OCR进程随机插入和删除了空格。

我有1000个例句的清单。对于每个句子,我都有原始的原始OCR输出和该句子的手动更正版本。手动更正句子需要在单词之间放置空格,并删除单词中间的空格。它还涉及在标点符号之前放置空格。这样,我将能够简单地使用some_string.split(" ")将其标记化。

这里是原始文本句子的示例,以及如何手动纠正它的外观。

原始OCR:

Itwasn'ta wonderf ulday, it rainedfora while, butatle ast
wegot to takea walkand get so me freshair.

已更正:

It was n't a wonderful day , it rained for a while , but at least 
we got to take a walk and get some fresh air .

我将句子分为900个训练句子和100个测试句子。

train_text = raw_sents[:900]
train_labels = corrected_text[:900]
test_text = raw_sents[900:]
test_labels = corrected_text[900:]

我已将训练语句分为原始和更正的单个字符串。

train_text = "".join(train_text)
train_labels = "".join(train_labels)
test_text = "".join(test_text
test_labels = "".join(test_labels)

这是我无法理解我要做什么的部分。我想做的是在每个时间步将train_text中的新字符输入到模型中,并让模型输出下一个字符应该是空格还是即将出现的字母。然后,我希望模型针对train_labels中的下一个字符进行测试,以查看其是否正确,然后反向传播。经过几个时期之后,我想将test_text输入到模型中,并根据“ test_labels”来衡量其准确性。

我从根本上误解了LSTM的工作原理吗?

我已经阅读了博客文章和观看过的教程,这些教程显示了正在创建的字符级LSTM,但是在现阶段,他们都将文本分成等长的序列,然后训练语言模型以预测下一个字符训练集中的字符作为标签,而不是使用不同的标签字符串。

chars = sorted(list(set(train_labels)))
no_chars = len(chars)
mapping = dict((c, i) for i, c in enumerate(chars))
sequences = list()
encoded_seq = [mapping[char] for char in train_labels]
sequences.append(encoded_seq)
sequences = np.array(train_labels)
X, y = sequences[:, : - 1], sequences[:, - 1]
sequences = [to_categorical(x, num_classes=no_chars) for x in X]
X = array(sequences)
y = to_categorical(y, num_classes=no_chars)

model = Sequential()
model.add(LSTM(75, input_shape=(X.shape[1], X.shape[2])))
model.add(Dense(no_chars, activation='softmax'))
print(model.summary())
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X, y, epochs=1000, verbose=2)

我应该抛弃我的train_text字符串,而只是在我的train_labels上使用这种方法训练语言模型吗?

这种方法的问题在于,序列的长度似乎决定了您需要向模型中输入相同数量的字符才能使其正常工作。因此,如果我将序列长度设置为10,然后测试了上面的raw-OCR示例,它将无法解决前十个字符Itwasn'ta。有什么办法可以克服这个问题?

理想情况下,我可以使用函数中的语言模型来纠正输入到函数中的任何字符串中的所有间距问题,但是我不知道该怎么做。

0 个答案:

没有答案