keras的函数拟合中有哪些参数?

时间:2018-07-25 00:40:22

标签: python tensorflow keras

我的问题是关于keras here中适合函数的参数。当我们有一个以x作为输入和y作为输出的模型时。然后,如果我们有一批带有输入x和目标输出y的样本,则可以使用model.fit(x,y)来更新模型。假设我们有一个损失函数l(x,y)model.fit(x,y)的作用是如下更新模型的权重w

  

delta = 0

     

对于批次中的每个样本i:增量<-增量+相对于w的损耗梯度(x_i,y_i)

     

w <-w +增量

现在,假设我们要更改更新规则,如下所示:

  

delta = 0

     

对于批次中的每个样本i:delta <-delta + q_i *相对于w的损耗梯度(x_i,y_i)

     

w <-w + n * delta

我想知道函数fit的哪些参数可用于对系数q_in进行建模:

  

fit(x = None,y = None,batch_size = None,epochs = 1,verbose = 1,   callbacks = None,validation_split = 0.0,validation_data = None,   shuffle = True,class_weight = None,sample_weight = None,initial_epoch = 0,   steps_per_epoch = None,validation_steps = None)

我想class_weightsample_weight中的一个应该与系数q_i有关,但我不知道哪个。任何想法?以及可以使用这些参数中的哪一个来建模n

1 个答案:

答案 0 :(得分:1)

变量q_i可以放入sample_weight中。您正在使用“针对批次中的每个样本”的权重。 (实际上,sample_weight将使损失成倍增加,但常数的作用类似于c.df/dx = d(c.f)/dx

变量n并不是很简单,但是在编译时,它与优化程序的学习率lr密切相关:

示例:

from keras.optimizers import Adam
optimizer = Adam(lr=0.001)

model.compile(optimizer = optimizer, loss = 'binary_crossentropy')    

class_weight呢?这将根据您的样本是哪个“输出类别”来应用权重。

虽然sample_weight的长度应与xsample_weight.shape[0] == x.shape[0])相同,但是class_weight应该是一个字典,例如:{0:w0, 1:w1, 2:w2}有。