来自Spark多项Logistic回归的意外系数

时间:2017-07-04 21:27:24

标签: apache-spark machine-learning pyspark logistic-regression multinomial

我在我的Mac,OS Sierra上运行Spark 2.1.1(这应该有用)。我尝试在我在网上找到的测试数据集上进行多项逻辑回归,我在这里报告前几行(我不知道如何在这里附加文件):

public class Foo {
    private String name;
    private Integer age;
    private Date birthday;

    @Override
    public String toString() {
        return "Foo [name=" + name + ", age=" + age + ", birthday=" + birthday + "]";
    }
}

第一列是标签('品牌',值:1,2,3),第二列和第三列是功能('性'和'年龄')。

由于标签有3个类别,因此多项逻辑回归应该执行3个二项式模型,然后从最大化该类概率的预测中选择预测。所以我希望模型返回3x2 coefficientMatrix:3,因为类是3和2,因为这些特征是2. This文档似乎与这个观点一致。

但是,惊喜......

1,0,24
1,0,26
1,0,26
1,1,27
1,1,27
3,1,27

coefficientMatrix是4x2,我有4个截距而不是3个。更奇怪的是:

>>> logit_model.coefficientMatrix
DenseMatrix(4, 2, [-1.2781, -2.8523, 0.0961, 0.5994, 0.6199, 0.9676, 0.5621, 1.2853], 1)
>>> logit_model.interceptVector
DenseVector([-4.5912, 13.0291, 1.2544, -9.6923])

由于一些奇怪的原因,模型“感觉”有4个班级,即使我只有3个班级(请参阅下面的代码进行检查)。

有什么建议吗? 非常感谢你。

以下是完整代码:

>>> logit_model.numClasses
4

以下是检查类只有3:

from pyspark.sql import functions as f
from pyspark.sql import types as t
from pyspark.ml import classification as cl
from pyspark.ml import feature as feat

customSchema = t.StructType(
    [t.StructField('brand', t.IntegerType(), True),
    t.StructField('sex', t.IntegerType(), True),
    t.StructField('age', t.IntegerType(), True)]
)

test_df01 = (
    spark
    .read
    .format('csv')
    .options(delimiter=',', header=False)
    .load('/Users/vanni/Downloads/mlogit_test.csv', schema=customSchema)
)

va = (
    feat.VectorAssembler()
    .setInputCols(['sex', 'age'])
    .setOutputCol('features')
)
test_df03 = (
    va
    .transform(test_df01)
    .drop('sex')
    .drop('age')
    .withColumnRenamed('brand', 'label')
)

logit_abst = (
    cl.LogisticRegression()
    .setFamily('multinomial')
    .setStandardization(False)
    .setThresholds([.5, .5, .5]) # to be adjusted after I know the actual values
    .setThreshold(None)
    .setMaxIter(100) # default
    .setRegParam(0.0) # default
    .setElasticNetParam(0.0) # default
    .setTol(1e-6) # default
)

logit_model = logit_abst.fit(test_df03)

1 个答案:

答案 0 :(得分:1)

这里没有什么奇怪的事。 Spark假定标签是连续的整数值,表示为DoubleType,从0开始。

由于你得到的最大标签是3,Spark假设标签实际上是0,1,2,3 - 即使数据集中从未出现过0。

如果不希望出现此行为,则应将标签重新编码为从零开始,或在原始标签上应用StringIndexer