在自回归连续问题中,如果零的位置过多,则可以将情况视为零膨胀问题(即ZIB)。换句话说,不是要拟合f(x)
,而是要拟合g(x)*f(x)
,其中f(x)
是我们要近似的函数,即y
和{{1} }是一个函数,根据值是零还是非零来输出介于0和1之间的值。
目前,我有两个模型。一个给我g(x)
的模型,另一个给我g(x)
的模型。
第一个模型给了我一组权重。这是我需要您帮助的地方。我可以将g(x)*f(x)
参数和sample_weights
一起使用。在处理大量数据时,需要使用model.fit()
。但是,model.fit_generator()
没有参数fit_generator()
。
是否可以在sample_weights
内使用sample_weights
?否则,如果我已经为fit_generator()
训练了模型,那么我怎么适合g(x)*f(x)
?
答案 0 :(得分:4)
您可以提供样本权重作为生成器返回的元组的第三个元素。摘自fit_generator
上的Keras文档:
生成器::生成器或
Sequence
(keras.utils.Sequence
)对象的实例,以避免在使用多处理时重复数据。发生器的输出必须是
- 元组
(inputs, targets)
- 元组
(inputs, targets, sample_weights)
。
更新:这是生成器的草图,该生成器返回输入的样本和目标以及从模型g(x)
获得的样本权重:
def gen(args):
while True:
for i in range(num_batches):
# get the i-th batch data
inputs = ...
targets = ...
# get the sample weights
weights = g.predict(inputs)
yield inputs, targets, weights
model.fit_generator(gen(args), steps_per_epoch=num_batches, ...)