I'm trying to use some machine learning algorithms with Spark MLlib.
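My CSV contains data of different types (strings, floats, doubles, ...). When I tried to use VectorAssembler to feed the data to the algorithms, I ran into a problem because some columns are strings. So I used StringIndexer, which converted my strings into numeric values, for example: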
John -> 1.0
Jose -> 2.0
Willian -> 3.0
Is there a way to convert the numeric values back into strings? I want to get the original strings back. I have looked at IndexToString, but it doesn't work.
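Conceptually, a fitted StringIndexerModel stores an array of labels: StringIndexer replaces a string with its position in that array, and IndexToString replaces a position with the string stored there. Spark's indices start at 0.0, and with the default stringOrderType ("frequencyDesc") the most frequent value gets 0.0. The plain-Java sketch that follows does not use Spark; the class LabelMapping and its methods are hypothetical, and the alphabetical tie-break assumes the behaviour documented for Spark 3.x. It illustrates why the reverse lookup only works with the labels array of the same model that produced the indices.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration of the StringIndexer / IndexToString idea (not Spark code).
public class LabelMapping {

    // Builds a labels array in the spirit of StringIndexer's default "frequencyDesc"
    // order: most frequent value first, ties broken alphabetically (assumed Spark 3.x behaviour).
    public static String[] fitLabels(List<String> values) {
        Map<String, Integer> counts = new HashMap<>();
        for (String v : values) {
            counts.merge(v, 1, Integer::sum);
        }
        List<String> labels = new ArrayList<>(counts.keySet());
        labels.sort((a, b) -> {
            int byCount = Integer.compare(counts.get(b), counts.get(a)); // descending frequency
            return byCount != 0 ? byCount : a.compareTo(b);             // alphabetical tie-break
        });
        return labels.toArray(new String[0]);
    }

    // Forward direction (what StringIndexer does): label -> index.
    public static double toIndex(String[] labels, String value) {
        for (int i = 0; i < labels.length; i++) {
            if (labels[i].equals(value)) {
                return i;
            }
        }
        throw new IllegalArgumentException("Unseen label: " + value);
    }

    // Reverse direction (what IndexToString does): index -> label.
    public static String toLabel(String[] labels, double index) {
        return labels[(int) index];
    }
}
```

For example, labels fitted on Jose, John, Willian, John come out as [John, Jose, Willian], so toLabel(labels, 1.0) returns "Jose". Decoding with a labels array learned from a different column (such as Temp) would map the same index to an unrelated value.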
Edit: here is the code.
The Dataset newDF contains all of my data.
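import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.DecisionTreeClassifier;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// sparkSession (a SparkSession) and newDF (a Dataset<Row>) are created earlier.
final VectorAssembler vectorAssembler = new VectorAssembler()
        .setInputCols(new String[] { "IDPatient_index", "Temp", "SPO2Min", "SPO2Max", "BPMmin", "BPMmax",
                "BPMavg", "SYS", "DIA", "EDAmin", "EDAmax", "EDAavg" })
        .setOutputCol("features");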
newDF.createOrReplaceTempView("tdd");
Dataset<Row> transformedDS = sparkSession.sql(
        "SELECT CAST(IDPatient as String) IDPatient, CAST(Temp as float) Temp, CAST(SPO2Min as float) SPO2Min, "
        + "CAST(SPO2Max as float) SPO2Max, CAST(BPMmin as float) BPMmin, CAST(BPMmax as float) BPMmax, CAST(BPMavg as float) BPMavg, "
        + "CAST(SYS as float) SYS, CAST(DIA as float) DIA, CAST(EDAmin as float) EDAmin, CAST(EDAmax as float) EDAmax, CAST(EDAavg as float) EDAavg "
        + "FROM tdd");
// StringIndexer
StringIndexerModel indexer = new StringIndexer()
        .setInputCol("IDPatient")
        .setOutputCol("IDPatient_index")
        .fit(transformedDS);
Dataset<Row> indexed = indexer.transform(transformedDS);
indexed.createOrReplaceTempView("tdd");
// final dataset
Dataset<Row> p1 = sparkSession.sql(
        "SELECT IDPatient_index, Temp, SPO2Min, SPO2Max, BPMmin, BPMmax, BPMavg, SYS, DIA, EDAmin, EDAmax, EDAavg FROM tdd");
final Dataset<Row> featuresData = vectorAssembler.transform(p1);
final StringIndexerModel labelIndexer = new StringIndexer().setInputCol("Temp").setOutputCol("indexedLabel")
        .fit(featuresData);
final VectorIndexerModel featureIndexer = new VectorIndexer().setInputCol("features")
        .setOutputCol("indexedFeatures").fit(featuresData);
Dataset<Row>[] splits = featuresData.randomSplit(new double[] { 0.7, 0.3 });
Dataset<Row> trainingFeaturesData = splits[0];
Dataset<Row> testFeaturesData = splits[1];
// Train a DecisionTree model.
final DecisionTreeClassifier dt = new DecisionTreeClassifier().setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures");
final IndexToString labelConverter = new IndexToString().setInputCol("prediction")
        .setOutputCol("predictedOccupancy").setLabels(labelIndexer.labels());
IndexToString back2string = new IndexToString().setInputCol("IDPatient_index").setOutputCol("IDPatientfin")
        .setLabels(labelIndexer.labels());
final Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[] { labelIndexer, featureIndexer, dt, labelConverter });
final PipelineModel model = pipeline.fit(trainingFeaturesData);
final Dataset<Row> predictions = model.transform(testFeaturesData);
System.out.println("Checking the conversion back to strings");
predictions.show();
System.out.println("Example records with Predicted Temp as 0:");
predictions.select("predictedOccupancy", "Temp", "features")
        .where(predictions.col("predictedOccupancy").equalTo(36.31)).show(10);
System.out.println("Example records with Predicted Temp as 1:");
predictions.select("predictedOccupancy", "Temp", "features")
        .where(predictions.col("predictedOccupancy").equalTo(36.3)).show(10);
System.out.println("Example records with incorrect predictions:");
predictions.select("predictedOccupancy", "Temp", "features")
        .where(predictions.col("predictedOccupancy").notEqual(predictions.col("Temp"))).show();
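In the code above, back2string is given labelIndexer.labels(), which are the labels learned from the Temp column rather than from IDPatient, and back2string is never applied to any Dataset. A minimal sketch, assuming spark-sql and spark-mllib (2.x or 3.x) are on the classpath, of a round trip in which IndexToString receives the labels of the StringIndexerModel fitted on IDPatient. The class IndexRoundTrip, the method roundTrip and the sample IDs are hypothetical.

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Hypothetical example: index IDPatient, then decode IDPatient_index back to strings.
public class IndexRoundTrip {

    public static Dataset<Row> roundTrip(SparkSession spark, List<String> ids) {
        // One-column DataFrame named IDPatient, like the column in the question.
        Dataset<Row> df = spark.createDataset(ids, Encoders.STRING()).toDF("IDPatient");

        StringIndexerModel indexer = new StringIndexer()
                .setInputCol("IDPatient")
                .setOutputCol("IDPatient_index")
                .fit(df);
        Dataset<Row> indexed = indexer.transform(df);

        // The labels come from the model fitted on IDPatient, not from the one fitted on Temp.
        IndexToString back2string = new IndexToString()
                .setInputCol("IDPatient_index")
                .setOutputCol("IDPatientfin")
                .setLabels(indexer.labels());
        return back2string.transform(indexed);
    }

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("IndexRoundTrip")
                .master("local[1]")
                .getOrCreate();
        roundTrip(spark, Arrays.asList("John", "Jose", "Willian", "John")).show();
        spark.stop();
    }
}
```

Applied to the pipeline in the question, the same idea would mean building back2string with indexer.labels() and calling back2string.transform(predictions), since predictions still carries the IDPatient_index column. In Spark 3.x, labels() is deprecated in favour of labelsArray(), but it still works for a single-column indexer.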
Thank you all very much!