实施"发电机"是否合理?对于Keras' fit_generator方法()没有调用yield()?

时间:2018-03-11 21:35:34

标签: python keras generator

我使用Keras训练神经网络,我已经达到了我的数据集大于计算机上安装的RAM数量的程度,所以它的时间已经到了修改我的训练脚本以调用model.fit_generator()而不是model.fit(),这样我就不必立即将所有训练和验证数据加载到RAM中。

我已经做了修改,AFAICT工作正常,但有一件事让我烦恼 - fit_generator()的所有示例用法我都是在网上看到使用Python的yield功能来存储生成器的状态。我是一位老C ++程序员,并且怀疑yield这些我不完全理解的功能,因此我想明确而不是隐含地维护我的发电机状态,所以相反,我实现了我的生成器:

class DataGenerator:
   def __init__(self, inputFileName, maxExamplesPerBatch):
      self._inputFileName       = inputFileName
      self._maxExamplesPerBatch = maxExamplesPerBatch

      self._inputsFile = open(inputFileName, "rb")
      if (self._inputsFile == None):
         self._print("Couldn't open file %s to read input data" % inputFileName)
         sys.exit(10)

      self._outputsFile = open(inputFileName, "rb")   # yes, we're deliberately opening the same file twice (to avoid having to call seek() a lot)
      if (self._outputsFile == None):
         self._print("Couldn't open file %s to read output data" % inputFileName)
         sys.exit(10)

      headerInfo = struct.unpack("<4L", self._inputsFile.read(16))          
      if (headerInfo[0] != 1414676815):
         print("Bad magic number in input file [%s], aborting!" % inputFileName)
         sys.exit(10)

      self._numExamples   = headerInfo[1]  # Number of input->output rows in our data-file (typically quite large, i.e. millions)
      self._numInputs     = headerInfo[2]  # Number of input values in each row
      self._numOutputs    = headerInfo[3]  # Number of output values in row
      self.seekToTopOfData()

   def __len__(self):
      return (math.ceil(self._numExamples/self._maxExamplesPerBatch))

   def __next__(self):
      numExamplesToLoad = self._maxExamplesPerBatch
      numExamplesLeft   = self._numExamples - self._curExampleIdx
      if (numExamplesLeft < numExamplesToLoad):
         numExamplesToLoad = numExamplesLeft
      inputData  = np.reshape(np.fromfile(self._inputsFile,  dtype='<f4', count=(numExamplesToLoad*self._numInputs)),  (numExamplesToLoad, self._numInputs))
      outputData = np.reshape(np.fromfile(self._outputsFile, dtype='<f4', count=(numExamplesToLoad*self._numOutputs)), (numExamplesToLoad, self._numOutputs))
      self._curExampleIdx += numExamplesToLoad
      if (self._curExampleIdx == self._numExamples):
         self.seekToTopOfData()
      return (inputData, outputData)   # <----- NOTE return, not yield!!

   def seekToTopOfData(self):
      self._curExampleIdx = 0
      self._inputsFile.seek(16)
      self._outputsFile.seek(16+(self._numExamples*self._numInputs*4))

[...]

trainingDataGenerator   = DataGenerator(trainingInputFileName, maxExamplesPerBatch)
validationDataGenerator = DataGenerator(validationInputFileName, maxExamplesPerBatch)

model.fit_generator(generator=trainingDataGenerator, steps_per_epoch=len(trainingDataGenerator), epochs=maxEpochs, callbacks=callbacks_list, validation_data=validationDataGenerator, validation_steps=len(validationDataGenerator))

...请注意,我的__next __(自我)函数以return而不是yield结尾,并且我明确地存储了生成器的状态(通过私有 - DataGenerator对象中的成员变量)而不是隐式(通过yield魔术)。这似乎工作正常。

我的问题是,这种不寻常的方法是否会引入我应该注意的任何非显而易见的行为问题?

1 个答案:

答案 0 :(得分:1)

对您的代码进行表面检查。当您编写生成器函数并调用它时,调用将返回一个生成器,其__next__方法通常会被迭代重复调用,直到它引发StopIteration异常。

生成器是迭代器的特例。像列表这样的Iterables有一个生成迭代器的__iter__方法。

除非你想将值发送到你的生成器并将它们取出,你的DataGenerator是实现迭代器的合理方法,但要写一个你需要的迭代另一个类__iter__方法返回DataGenerator的实例。

How to implement __iter__(self) for a container object (Python)的答案可能也会有所帮助。