Spark:setNumClasses(),用于Multiclass LogisticRegressionModel的标签子集

时间:2016-03-10 00:24:40

标签: apache-spark logistic-regression apache-spark-mllib

我有一个包含id(标签)的数据库,范围从1到1040.我使用Multiclass Logistic回归来预测id。现在,如果我只想训练标签的一个子集,让我们说从800到810.当我设置setNumClasses(11)时,我得到一个错误 - 11个类。我必须始终将此方法设置为类的最大值,即1040.这样,训练模型将训练从0到1040的所有标签,这非常昂贵并且使用大量资源。

我是否正确地说明了这一点?如何通过给出setNumClasses(count_of_classes)来训练我的模型仅用于标签的子集。

final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
            .setNumClasses(811).run(train.rdd());

2 个答案:

答案 0 :(得分:4)

根据预览答案的评论,我发现第二个评论是主查询。如果设置setNumClasses(23)意味着:在列车集中,所有类都应该在(0到22)的范围内。检查(docs)。它写成:

  

:: Experimental ::设置Multinomial Logistic回归中k类分类问题的可能结果数。默认情况下,它是二元逻辑回归,因此k将设置为2。

这意味着,对于二进制逻辑回归,二进制值/类是(0和1),因此setNumClasses(2)是默认值。

如果您有其他类如2,3,4,在火车组中,对于二进制分类,它将无效。

建议的解决方案:如果您的列车集或子集包含790 - 801和900 - 910类,则将数据规范化或转换为(0到22)并将23作为setNumClasses(23)。

答案 1 :(得分:2)

你不能这样做,你提供了一组训练数据,它可能在Spark的渐变下降方法中失败了(不确定,因为你还没有提供错误信息)。

另外,Spark应该如何确定应该训练模型的800个标签?

您应该做的是仅使用您要训练模型的标签过滤掉RDD中的行。例如,假设您的标签是从0到1040的值,并且您只想训练标签0到800,您可以这样做:

  userid    pagetag     time
  111       1-2      19:00:11
  111       1-2      19:00:12
 *111       1-2      19:08:02*
 *113       1-3      13:02:04*
  113       1-2      13:04:08
 *115       1-2      14:14:22*
  115       1-2      14:18:56

@Edit:是的,它当然可以选择一组不同的标签,这只是一个例子,只需将过滤方法更改为:

val actualTrainingRDD = train.filter( _.label < 801 )
final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
        .setNumClasses(801).run(train.rdd());

这是Scala,Java闭包使用train.filter( row => (row.label >= 790 && row.label < 801) ) ,对吧?