Pyspark:SVM_model.predict() - > "尺寸不匹配"

时间:2016-11-14 15:32:36

标签: python apache-spark pyspark svm apache-spark-mllib

我试图在~20k LabeledPoints的集合上测试各种候选模型,其中有〜5000个特征用于二进制分类。

我在输入空间上使用VectorAssembler创建了一个包含数字特征,稀疏矢量等的特征向量,然后将这些向量(以及标签)转换为LabeledPoints。

assembler = VectorAssembler(inputCols = ["fileSize", "hour", "day", "month", "title_tfidf", "ct_tfidf", "excerpt_tfidf", "regex_tfidf"], outputCol="features")
training_set_vector = assembler.transform(training_set).select("Target", "features")
training_set_labeled = training_set_vector.rdd.map(lambda x: Row(label=x[0],features=DenseVector(x[1].toArray()))).map(lambda row: LabeledPoint(row[1], [row[0]]))

training_set_vector和training_set_labeled的输出如下:

  

[Row(Target = 0,features = SparseVector(2273,{0:8397.0,1:13.0,2:   12.0,3:1.0,8:3.3147,82:5.721,370:0.2546,410:7.0356,418:5.3786,429:12.4219,623:7.3906,1061:9.3078,1637:2.3682,1647:2.716}))] [ LabeledPoint(0.0,[8397.0,13.0,12.0,1.0,0.0,0.0,0.0,0.0,3.31468140319,0.0,0.0,0.0,0.0,0.0,....,0.0,0.0,0.0,0.0,0.0,0.0, 0.0,0.0,0.0)]

然后我拿了LabeledPoints并适合RandomForest,SVM和GBT模型。

RF_model = RandomForest.trainClassifier(training_set_labeled, numClasses=2, categoricalFeaturesInfo = {}, numTrees=10, featureSubsetStrategy = "auto", impurity='gini', maxDepth=4, maxBins=32)
SVM_model = SVMWithSGD.train(training_set_labeled, iterations=500)
GBT_model = GradientBoostedTrees.trainClassifier(training_set_labeled, categoricalFeaturesInfo = {}, maxDepth=4, maxBins=32, numIterations=5)

到目前为止,一切都顺利进行,当我尝试将这些模型应用于测试集时(与训练集相同的维度)出现了问题。以下是我用于测试集的代码:

predictions_RF = RF_model.predict(test_set_labeled.map(lambda r: r.features))
predictions_SVM = SVM_model.predict(test_set_labeled.map(lambda r: r.features))
predictions_GBT = GBT_model.predict(test_set_labeled.map(lambda r: r.features))

RF和GBT模型成功完成:

print predictions_RF.take(5)
print predictions_GBT.take(5)
  

[0.0,0.0,0.0,0.0,0.0]   [0.0,0.0,1.0,0.0,0.0]

但是在应用SVM模型时,我收到以下错误:

  

AssertionError:尺寸不匹配

运行"打印SVM_model"显示只有~500个重量,但有~5k的功能。我假设这是问题,但我不完全确定如何处理这个问题。任何人都有类似的问题,能够指出我如何将这个SVM模型(或其他一些SVM模型)应用于测试集?

0 个答案:

没有答案