将sample_weights与fit_generator()一起使用

时间:2018-11-17 19:59:29

标签: machine-learning keras time-series generator autoregressive-models

在自回归连续问题中,如果零的位置过多,则可以将情况视为零膨胀问题(即ZIB)。换句话说,不是要拟合f(x),而是要拟合g(x)*f(x),其中f(x)是我们要近似的函数,即y和{{1} }是一个函数,根据值是零还是非零来输出介于0和1之间的值。

目前,我有两个模型。一个给我g(x)的模型,另一个给我g(x)的模型。

第一个模型给了我一组权重。这是我需要您帮助的地方。我可以将g(x)*f(x)参数和sample_weights一起使用。在处理大量数据时,需要使用model.fit()。但是,model.fit_generator()没有参数fit_generator()

是否可以在sample_weights内使用sample_weights?否则,如果我已经为fit_generator()训练了模型,那么我怎么适合g(x)*f(x)

1 个答案:

答案 0 :(得分:4)

您可以提供样本权重作为生成器返回的元组的第三个元素。摘自fit_generator上的Keras文档:

  

生成器::生成器或Sequencekeras.utils.Sequence)对象的实例,以避免在使用多处理时重复数据。发生器的输出必须是

     
      
  • 元组(inputs, targets)
  •   
  • 元组(inputs, targets, sample_weights)
  •   

更新:这是生成器的草图,该生成器返回输入的样本和目标以及从模型g(x)获得的样本权重:

def gen(args):
    while True:
        for i in range(num_batches):
            # get the i-th batch data
            inputs = ...
            targets = ...

            # get the sample weights
            weights = g.predict(inputs)

            yield inputs, targets, weights


model.fit_generator(gen(args), steps_per_epoch=num_batches, ...)