sample_weight is not recognized in skgarden.quantile.RandomForestQuantileRegressor.fit()

Date: 2020-08-18 17:40:44

Tags: python scikit-learn

The documentation at https://scikit-garden.github.io/api/#skgardenquantile_1 lists sample_weight as an available parameter of RandomForestQuantileRegressor.fit(X, y), but passing it raises an error. It works for DecisionTreeQuantileRegressor but not for RandomForestQuantileRegressor. Is this intentional?

Example:

import numpy as np
import skgarden 

# we create 20 points
np.random.seed(0)
X = np.r_[np.random.randn(10, 2) + [1, 1], np.random.randn(10, 2)]
y = [1] * 10 + [-1] * 10
sample_weight_last_ten = abs(np.random.randn(len(X)))
sample_weight_constant = np.ones(len(X))
# and give bigger weights to some outliers
sample_weight_last_ten[15:] *= 5
sample_weight_last_ten[9] *= 15

# fit the model WITH WEIGHTS
clf_weights = skgarden.DecisionTreeQuantileRegressor(random_state=0)
clf_weights.fit(X, y, sample_weight=sample_weight_last_ten)

This trains correctly, but when I try:

import numpy as np
import skgarden 

# we create 20 points
np.random.seed(0)
X = np.r_[np.random.randn(10, 2) + [1, 1], np.random.randn(10, 2)]
y = [1] * 10 + [-1] * 10
sample_weight_last_ten = abs(np.random.randn(len(X)))
sample_weight_constant = np.ones(len(X))
# and give bigger weights to some outliers
sample_weight_last_ten[15:] *= 5
sample_weight_last_ten[9] *= 15

# fit the model WITH WEIGHTS
clf_weights = skgarden.quantile.RandomForestQuantileRegressor(random_state=0)
clf_weights.fit(X, y, sample_weight=sample_weight_last_ten)

this error is raised:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-375-dfda3c67445e> in <module>
     15 # fit the model WITH WEIGHTS
     16 clf_weights = skgarden.quantile.RandomForestQuantileRegressor(random_state=0)
---> 17 clf_weights.fit(X, y, sample_weight=sample_weight_last_ten)

TypeError: fit() got an unexpected keyword argument 'sample_weight'
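The TypeError means the installed fit() method has no sample_weight parameter, whatever the documentation page says. A minimal sketch for checking this from code before calling fit(), using only the standard library's inspect module. The helper name fit_accepts and the classes TreeLike and ForestLike are hypothetical stand-ins that mimic the two behaviours in the question; they are not skgarden classes.

```python
import inspect


def fit_accepts(estimator, param="sample_weight"):
    """Return True if estimator.fit names `param` or takes **kwargs (hypothetical helper)."""
    params = inspect.signature(estimator.fit).parameters
    if param in params:
        return True
    # A fit(..., **kwargs) signature would also accept the keyword at call time.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins (NOT skgarden classes) for the two cases in the question.
class TreeLike:
    def fit(self, X, y, sample_weight=None):
        return self


class ForestLike:
    def fit(self, X, y):
        return self


print(fit_accepts(TreeLike()))    # True
print(fit_accepts(ForestLike()))  # False
```

With skgarden installed, passing an instance such as skgarden.quantile.RandomForestQuantileRegressor() to fit_accepts would report what the installed version actually accepts. scikit-learn ships a similar helper, sklearn.utils.validation.has_fit_parameter, which checks the named parameters of fit in the same way.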

Is there another way to give each sample a weight during training?
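One generic workaround, offered here as an assumption rather than anything skgarden provides, is weighted resampling: draw row indices with probability proportional to the weights and fit the unweighted forest on the resampled rows. Rows with larger weights then appear more often in the training set, which approximates weighting them. A minimal sketch using the standard library's random.choices; the helper name weighted_resample_indices is hypothetical.

```python
import random


def weighted_resample_indices(sample_weight, n_draws=None, seed=0):
    """Draw row indices with probability proportional to sample_weight.

    Hypothetical helper: fitting an unweighted model on the resampled rows
    approximates a weighted fit. This is not a skgarden feature.
    """
    weights = [float(w) for w in sample_weight]
    if any(w < 0 for w in weights) or sum(weights) == 0:
        raise ValueError("sample_weight must be non-negative and not all zero")
    n = len(weights)
    k = n if n_draws is None else n_draws
    rng = random.Random(seed)  # seeded so the resample is reproducible
    return rng.choices(range(n), weights=weights, k=k)


# Usage with the arrays from the question (needs numpy and skgarden, not run here):
# idx = weighted_resample_indices(sample_weight_last_ten, n_draws=10 * len(X))
# X_rs, y_rs = X[idx], np.asarray(y)[idx]
# clf = skgarden.quantile.RandomForestQuantileRegressor(random_state=0).fit(X_rs, y_rs)
```

The result is only an approximation: it is random, and the forest's own bootstrap then resamples again from the resampled rows. Drawing more rows than the original sample size (a larger n_draws) reduces the noise that the weighting step itself adds.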

0 answers