我有一个标有4个不同类别的概率值的数据,我也想以这种方式预测。但是我找不到任何算法处理这样的数据并预测每个类的概率值而不是binariez顺序。我可以在这样的问题中使用什么?
答案 0 :(得分:4)
scikit-learn分类器不支持开箱即用的概率分布培训。解决方法是将样本 K 次提供给 K 类的训练算法,使用概率分布作为sample_weight
。并非所有分类器都支持这一点,但SGDClassifier
确实适用于具有正确设置的逻辑回归模型。
为了举例,让我们做一个随机的训练集。
>>> X = np.random.randn(10, 6)
>>> p_pos = np.random.random_sample(10)
>>> p_pos
array([ 0.19751302, 0.01538067, 0.87723187, 0.63745719, 0.38188726,
0.62435933, 0.3706495 , 0.12011895, 0.61787941, 0.82476533])
现在将其提供给使用SGD训练的逻辑回归模型,两次。
>>> lr = SGDClassifier(loss="log")
>>> y = p_pos > .5
>>> lr.fit(np.vstack([X, X]), np.hstack([np.ones(10), np.zeros(10)]),
... sample_weight=np.hstack([p_pos, 1 - p_pos]))
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
fit_intercept=True, l1_ratio=0.15, learning_rate='optimal',
loss='log', n_iter=5, n_jobs=1, penalty='l2', power_t=0.5,
random_state=None, shuffle=False, verbose=0, warm_start=False)
前面的例子是二进制LR。 Multiclass LR有点复杂。假设你有一个n_samples
概率分布的矩阵P,每个都是行向量:
>>> P = np.abs(np.random.randn(10, 4))
>>> P /= P.sum(axis=1).reshape(-1, 1) # normalize
>>> P
array([[ 0.22411769, 0.06275884, 0.25062665, 0.46249682],
[ 0.20659542, 0.06153031, 0.03973449, 0.69213978],
[ 0.20214651, 0.084988 , 0.12751119, 0.5853543 ],
[ 0.35839192, 0.30211805, 0.01093208, 0.32855796],
[ 0.34267131, 0.07151225, 0.09413323, 0.4916832 ],
[ 0.26670351, 0.30988833, 0.22118608, 0.20222208],
[ 0.00694437, 0.68845955, 0.18413326, 0.12046281],
[ 0.34344352, 0.27397581, 0.34626692, 0.03631376],
[ 0.29315434, 0.25683875, 0.14935136, 0.30065555],
[ 0.19147437, 0.22572122, 0.57924412, 0.00356029]])
现在我们有四个班级,所以我们需要将训练集提供给估算器四次。
>>> n_classes = P.shape[1]
>>> X4 = np.vstack([X for i in xrange(n_classes)])
>>> y = np.arange(n_classes).repeat(10)
>>> sample_weight = P.T.ravel()
>>> lr.fit(X4, y, sample_weight=sample_weight)
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
fit_intercept=True, l1_ratio=0.15, learning_rate='optimal',
loss='log', n_iter=5, n_jobs=1, penalty='l2', power_t=0.5,
random_state=None, shuffle=False, verbose=0, warm_start=False)