为RNN创建令牌:“ IndexError:数组的索引过多”

时间:2019-08-14 00:57:14

标签: python numpy keras tokenize

我正在从事一些机器学习项目,并找到了我想修改的文本RNN教程(如果可以再次找到它,将会链接),但是我似乎找不到我的代码出了问题。

我可以使用以下简短脚本来复制问题:

import string

from numpy import array
from keras.preprocessing.text import Tokenizer

def load_doc(path):

    file = open(path, 'r')
    text = file.read()
    file.close()
    return text

def get_tokens(text):

    tokens = text.split()

    # I would like to remove the next 4 lines
    table = str.maketrans('', '', string.punctuation)
    tokens = [w.translate(table) for w in tokens]
    tokens = [word for word in tokens if word.isalpha()]
    tokens = [word.lower() for word in tokens]

    return tokens

def get_sequences(sequence_length, token_length):
    sequences = list()
    for i in range(sequence_length, token_length):
        seq = tokens[i - sequence_length: i]
        line = ' '.join(seq)
        sequences.append(line)
    return sequences

data = load_doc("C:\data.txt")  
tokens = get_tokens(data)
sequences = get_sequences(10, len(tokens))

tok = Tokenizer()
tok.fit_on_texts(sequences)
encoded = tok.texts_to_sequences(sequences)

#sanity check
print("Prenumpy: \n" + str(encoded[:10]))
encoded = array(encoded)
print(encoded.shape)
print("Postnumpy: \n" + str(encoded[:10]))
X = encoded[:,:-1]

当我从get_tokens(text)删除行时,出现错误:IndexError: too many indices for array。比较形状,工作代码为:(11984,10),失败代码为:(12388,)。我了解使用工作代码,因为我将拥有更少的令牌,因此第一维将更小。另一个重要区别是将encoded转换为numpy数组之前和之后的输出。

预制(缩短空间,每次运行之间的数字不同,因为使用了不同的令牌):

[[2, 1318, 3, 161, 203, 9, 17, 297, 11, 21], ..., [21, 8, 325, 11, 12, 212, 48, 9, 235, 577]]

后期制作(缩短):

[[   2 1285    3  139  181    9   17  275   11   21]
 [1285    3  139  181    9   17  275   11   21    8]
 ...
 [  11   21    8  303   11   12  190   46    9  213]
 [  21    8  303   11   12  190   46    9  213  554]]

后播失败(缩短):

[list([2, 1318, 3, 161, 203, 9, 17, 297, 11, 21])
 list([1318, 3, 161, 203, 9, 17, 297, 11, 21, 8])
 ...
 list([11, 21, 8, 325, 11, 12, 212, 48, 9, 235])
 list([21, 8, 325, 11, 12, 212, 48, 9, 235, 577])]

我假设我得到此错误的原因是因为numpy转换数组的方式不同。为什么numpy更改数据类型?

0 个答案:

没有答案