keras :: fit()错误中的class_weight参数。类存在于数据中,但不存在于class_weight中

时间:2019-04-08 15:56:57

标签: r class keras deep-learning

我要为其分配权重的数据集有些不平衡。

How to set class_weight in keras package of R?中提供的示例不适用于我。当我尝试相同的操作时,使用我的代码:

system.time ( 
  baseline_history <- fit (
    object           = model_baseline,            
    x                = as.matrix(x_train_tbl), 
    y                = y_train_vec,             
    batch_size       = 1024,    
    epochs           = 30,    
    class_weight = list("0" = 1, "1" = 1.67),
    validation_split = 0.2) )

我收到以下错误:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: `class_weight` must contain all classes in the data. The classes {'0', '1'} exist in the data but not in `class_weight`.

我有点茫然,因为我在class_weights中明确声明它是一个列表。我什至尝试过

weights <- list("0" = 1, "1" = 1.67)
> weights
$`0`
[1] 1

$`1`
[1] 1.67

is.list(weights)
[1] TRUE

为了确保它能正常工作,但是我仍然遇到相同的错误。有什么想法吗?

1 个答案:

答案 0 :(得分:0)

我假设您将y_train_vec作为一个因素,这就是问题所在。

由于某些原因,class_weight似乎不适用于因子,因此您只需将其更改为数字即可

y_train_vec = as.numeric(y_train_vec)

这应该为您提供因子的内部表示形式(应该为您提供1和2的列表),然后您可以相应地指定class_weight

system.time(
    baseline_history <- fit (
        object           = model_baseline,            
        x                = as.matrix(x_train_tbl), 
        y                = as.matrix(y_train_vec),             
        batch_size       = 1024,    
        epochs           = 30,    
        class_weight = list("1" = 1, "2" = 1.67),
        validation_split = 0.2)
    )
)

现在请注意,class_weight中的类是“ 1”和“ 2”

希望这会有所帮助。