Keras:在predict_generator中使用use_multiprocessing = True可以提供比预期更多的预测吗?

时间:2018-03-12 01:41:44

标签: python nlp deep-learning keras

由于内存问题,我正在处理语言建模问题并使用predict_generator函数。我面临的问题是predict_generator比输入的大小给出更多的预测。

我在predict_generator函数中提供的参数:

predictions = model.predict_generator(testDataGenerator(statements),
                                                  use_multiprocessing=True,workers=4,
                                                  steps=25,
                                                  verbose=1)

生成器功能:

def testDataGenerator(testDataFrame):
        testDataFrame.reset_index(drop=True, inplace=True)
        startPoint = 0
        endPoint = 64
        while True:
            statementSet = testDataFrame[startPoint:endPoint]
            test = buildTrainAndTestSets(statementSet)
            startPoint = endPoint
            endPoint += 64
            yield test

我总共有1568个输入,而且我一共发送了64个,但我得到了1600个预测。错误输出是:

25/25 [==============================] - 47s 2s/step
IndexError: Length of values does not match length of index

我认为我在生成函数中发送语句的方式就是问题。

1 个答案:

答案 0 :(得分:0)

如果您使用自定义生成器,则必须对预测变量的最后一步有所注意。

由于您要执行25步,批量大小为64,因此生成器希望您的数据恰好为1600,因此,我认为生成器中一个简单的 if 可以更改端点即可解决您的问题。 / p>