我使用Keras训练神经网络,我已经达到了我的数据集大于计算机上安装的RAM数量的程度,所以它的时间已经到了修改我的训练脚本以调用model.fit_generator()而不是model.fit(),这样我就不必立即将所有训练和验证数据加载到RAM中。
我已经做了修改,AFAICT工作正常,但有一件事让我烦恼 - fit_generator()的所有示例用法我都是在网上看到使用Python的yield
功能来存储生成器的状态。我是一位老C ++程序员,并且怀疑yield
这些我不完全理解的功能,因此我想明确而不是隐含地维护我的发电机状态,所以相反,我实现了我的生成器:
class DataGenerator:
def __init__(self, inputFileName, maxExamplesPerBatch):
self._inputFileName = inputFileName
self._maxExamplesPerBatch = maxExamplesPerBatch
self._inputsFile = open(inputFileName, "rb")
if (self._inputsFile == None):
self._print("Couldn't open file %s to read input data" % inputFileName)
sys.exit(10)
self._outputsFile = open(inputFileName, "rb") # yes, we're deliberately opening the same file twice (to avoid having to call seek() a lot)
if (self._outputsFile == None):
self._print("Couldn't open file %s to read output data" % inputFileName)
sys.exit(10)
headerInfo = struct.unpack("<4L", self._inputsFile.read(16))
if (headerInfo[0] != 1414676815):
print("Bad magic number in input file [%s], aborting!" % inputFileName)
sys.exit(10)
self._numExamples = headerInfo[1] # Number of input->output rows in our data-file (typically quite large, i.e. millions)
self._numInputs = headerInfo[2] # Number of input values in each row
self._numOutputs = headerInfo[3] # Number of output values in row
self.seekToTopOfData()
def __len__(self):
return (math.ceil(self._numExamples/self._maxExamplesPerBatch))
def __next__(self):
numExamplesToLoad = self._maxExamplesPerBatch
numExamplesLeft = self._numExamples - self._curExampleIdx
if (numExamplesLeft < numExamplesToLoad):
numExamplesToLoad = numExamplesLeft
inputData = np.reshape(np.fromfile(self._inputsFile, dtype='<f4', count=(numExamplesToLoad*self._numInputs)), (numExamplesToLoad, self._numInputs))
outputData = np.reshape(np.fromfile(self._outputsFile, dtype='<f4', count=(numExamplesToLoad*self._numOutputs)), (numExamplesToLoad, self._numOutputs))
self._curExampleIdx += numExamplesToLoad
if (self._curExampleIdx == self._numExamples):
self.seekToTopOfData()
return (inputData, outputData) # <----- NOTE return, not yield!!
def seekToTopOfData(self):
self._curExampleIdx = 0
self._inputsFile.seek(16)
self._outputsFile.seek(16+(self._numExamples*self._numInputs*4))
[...]
trainingDataGenerator = DataGenerator(trainingInputFileName, maxExamplesPerBatch)
validationDataGenerator = DataGenerator(validationInputFileName, maxExamplesPerBatch)
model.fit_generator(generator=trainingDataGenerator, steps_per_epoch=len(trainingDataGenerator), epochs=maxEpochs, callbacks=callbacks_list, validation_data=validationDataGenerator, validation_steps=len(validationDataGenerator))
...请注意,我的__next __(自我)函数以return
而不是yield
结尾,并且我明确地存储了生成器的状态(通过私有 - DataGenerator对象中的成员变量)而不是隐式(通过yield
魔术)。这似乎工作正常。
我的问题是,这种不寻常的方法是否会引入我应该注意的任何非显而易见的行为问题?
答案 0 :(得分:1)
对您的代码进行表面检查。当您编写生成器函数并调用它时,调用将返回一个生成器,其__next__
方法通常会被迭代重复调用,直到它引发StopIteration
异常。
生成器是迭代器的特例。像列表这样的Iterables有一个生成迭代器的__iter__
方法。
除非你想将值发送到你的生成器并将它们取出,你的DataGenerator
是实现迭代器的合理方法,但要写一个你需要的迭代另一个类__iter__
方法返回DataGenerator
的实例。
How to implement __iter__(self) for a container object (Python)的答案可能也会有所帮助。