由于内存问题,我正在处理语言建模问题并使用predict_generator函数。我面临的问题是predict_generator比输入的大小给出更多的预测。
我在predict_generator函数中提供的参数:
predictions = model.predict_generator(testDataGenerator(statements),
use_multiprocessing=True,workers=4,
steps=25,
verbose=1)
生成器功能:
def testDataGenerator(testDataFrame):
testDataFrame.reset_index(drop=True, inplace=True)
startPoint = 0
endPoint = 64
while True:
statementSet = testDataFrame[startPoint:endPoint]
test = buildTrainAndTestSets(statementSet)
startPoint = endPoint
endPoint += 64
yield test
我总共有1568个输入,而且我一共发送了64个,但我得到了1600个预测。错误输出是:
25/25 [==============================] - 47s 2s/step
IndexError: Length of values does not match length of index
我认为我在生成函数中发送语句的方式就是问题。
答案 0 :(得分:0)
如果您使用自定义生成器,则必须对预测变量的最后一步有所注意。
由于您要执行25步,批量大小为64,因此生成器希望您的数据恰好为1600,因此,我认为生成器中一个简单的 if 可以更改端点即可解决您的问题。 / p>