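I am working on text generation with a sequence-to-sequence model and have run into a problem I cannot solve. Description: I train a neural network on questions and answers. When I prepare the data for the decoder, I get the following tensor dimensions: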
tokenizedAnswers = tokenizer.texts_to_sequences(answers)  # convert the text answers to sequences of token indexes
for i in range(len(tokenizedAnswers)):
    tokenizedAnswers[i] = tokenizedAnswers[i][1:]  # drop the START tag from each tokenized answer
paddedAnswers = pad_sequences(tokenizedAnswers, maxlen=maxLenAnswers, padding='post')  # pad shorter answers with zeros so all sequences have the same length
decoderForOutput = utils.to_categorical(paddedAnswers, vocabularySize)  # convert to one-hot vectors
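At this stage, the paddedAnswers variable holds a two-dimensional NumPy array of shape (4594, 25), which matches my dataset, and the decoderForOutput variable holds a three-dimensional NumPy array of shape (4594, 25, 10785). This also matches my dataset.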
paddedAnswers.shape  # (4594, 25)
decoderForOutput.shape  # (4594, 25, 10785)
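To make these shapes concrete, here is a small NumPy-only sketch. The sizes (3 answers, padded length 4, vocabulary of 5) are made up for illustration, and indexing `np.eye` is used as a stand-in for `utils.to_categorical`, so this is an approximation of the step above rather than the exact Keras call.

```python
import numpy as np

# Hypothetical toy sizes: 3 answers, padded length 4, vocabulary of 5 tokens.
padded = np.array([[1, 2, 0, 0],
                   [3, 4, 2, 0],
                   [1, 0, 0, 0]])
vocab_size = 5

# Indexing an identity matrix with integer ids gives one one-hot row per id.
# This stands in for utils.to_categorical(padded, vocab_size).
one_hot = np.eye(vocab_size, dtype="float32")[padded]

print(padded.shape)   # (3, 4)
print(one_hot.shape)  # (3, 4, 5)
```

The one-hot step adds a trailing axis of vocabulary size, which is how (4594, 25) becomes (4594, 25, 10785) in my data.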
To save memory, I want to use Python generators here. For this I wrote two generator functions:
def generator_from_two_dimensional_tensor(arg):
    for i in range(arg.shape[0]):
        for j in range(arg.shape[1]):
            yield arg[i, j]

def generator_from_three_dimensional_tensor(arg):
    for i in range(arg.shape[0]):
        for j in range(arg.shape[1]):
            for k in range(arg.shape[2]):
                yield arg[i, j, k]
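For reference, this is what the first function yields on a tiny array. The array `small` is made up for illustration. The generator produces individual elements one at a time, so collecting them into a list gives a flat sequence rather than rows.

```python
import numpy as np

def generator_from_two_dimensional_tensor(arg):
    for i in range(arg.shape[0]):
        for j in range(arg.shape[1]):
            yield arg[i, j]

small = np.arange(6).reshape(2, 3)  # hypothetical 2 x 3 example array

# Collect everything the generator yields into a plain list of ints.
flat = [int(x) for x in generator_from_two_dimensional_tensor(small)]

print(flat)                  # [0, 1, 2, 3, 4, 5]
print(np.array(flat).shape)  # (6,)
```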
I substituted these functions into my code as shown below. I also made a generator from the tokenized answers:
tokenizedAnswer = tokenizer.texts_to_sequences(answers)
for i in range(len(tokenizedAnswer)):
    tokenizedAnswer[i] = tokenizedAnswer[i][1:]
generator_tokenized_answers = (x for x in tokenizedAnswer)
gen_paddedAnswers = generator_from_two_dimensional_tensor(pad_sequences([x for x in generator_tokenized_answers], maxlen=maxLenAnswers, padding='post'))
decoderForOutput = generator_from_three_dimensional_tensor(utils.to_categorical([x for x in gen_paddedAnswers], vocabularySize))
The code runs without raising any errors. But when I try to train the network, I get an error:
history = model.fit([[x for x in gen_encoderForInput], [y for y in gen_decoderForInput]], [z for z in decoderForOutput], batch_size=50, epochs=20)
<ipython-input-6-a3a67c64deaf> in generator_three_tensor(arg)
      2     for i in range(arg.shape[0]):
      3         for j in range(arg.shape[1]):
----> 4             for k in range(arg.shape[2]):
      5                 yield arg[i, j, k]
      6
IndexError: tuple index out of range
Why this happens is still a mystery to me...
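Here is a reduced reproduction without Keras, as a sketch only. It assumes the array reaching the second generator is two-dimensional rather than three-dimensional, which I have not confirmed is what happens in my pipeline, and it uses `np.eye` as a stand-in for `utils.to_categorical`. It shows that calling a generator function raises nothing, and that the IndexError appears only when the generator is first iterated. That would be consistent with the setup code running cleanly and the error surfacing inside model.fit.

```python
import numpy as np

def generator_from_three_dimensional_tensor(arg):
    for i in range(arg.shape[0]):
        for j in range(arg.shape[1]):
            for k in range(arg.shape[2]):
                yield arg[i, j, k]

# Hypothetical stand-in for a one-hot encoding of a flat list of ids:
# a 2-D array of shape (number_of_ids, vocabulary_size).
two_d = np.eye(5, dtype="float32")[[1, 2, 0, 3]]
print(two_d.shape)  # (4, 5)

# Creating the generator does not execute the function body, so no error yet.
gen = generator_from_three_dimensional_tensor(two_d)

try:
    next(gen)  # the body runs only now, and arg.shape has no index 2
    error_message = None
except IndexError as exc:
    error_message = str(exc)

print(error_message)  # tuple index out of range
```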