我一直在尝试学习/使用Scala进行机器学习,为此我需要将字符串变量转换为虚拟指数。
我完成它的方式是使用Scala中的StringIndexer
。在运行之前,我已使用df.na.fill("missing")
替换缺失值。即使在我跑完之后我仍然得到NullPointerException
。
我还应该做些什么或我应该检查的其他事情?我使用printSchema
仅对字符串列进行过滤,以获取运行StringIndexer
所需的列列表。
val newDf1 = reweight.na.fill("Missing")
val cat_cols = Array("highest_tier_nm", "day_of_week", "month",
"provided", "docsis", "dwelling_type_grp", "dwelling_type_cd", "market"
"bulk_flag")
val transformers: Array[org.apache.spark.ml.PipelineStage] = cat_cols
.map(cname => new StringIndexer()
.setInputCol(cname)
.setOutputCol(s"${cname}_index"))
val stages: Array[org.apache.spark.ml.PipelineStage] = transformers
val categorical = new Pipeline().setStages(stages)
val cat_reweight = categorical.fit(newDf)
答案 0 :(得分:1)
通常在使用机器学习时,您将使用一部分数据训练模型,然后使用另一部分进行测试。因此,有两种不同的方法可用于反映这一点。您只使用了fit()
,相当于训练模型(或管道)。
这意味着您的cat_reweight
不是数据框,而是PipelineModel
。 PipelineModel
具有函数transform()
,该函数采用与用于训练的格式相同的格式的数据,并将数据帧作为输出。换句话说,您应该在.transform(newDf1)
之后添加fit(newDf1)
。
另一个可能的问题是,在您的代码中,您使用的是fit(newDf)
而不是fit(newDf1)
。确保fit()
和transform()
方法都使用了正确的数据框,否则您将获得NullPointerException
。
在本地运行时它适用于我,但是,如果仍然出现错误,您可以在替换空值后再尝试cache()
,然后执行action以确保完成所有转换。< / p>
希望它有所帮助!