Convert StringIndexer output back to String - Apache Spark MLlib

Time: 2019-01-08 07:13:51

Tags: java apache-spark machine-learning apache-spark-sql apache-spark-mllib

I am trying to use some machine learning algorithms with Spark MLlib.

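I have data of different types in a CSV (strings, floats, doubles...). When I tried to use VectorAssembler to feed the data to the algorithms, I ran into a problem because of the strings, so I used StringIndexer, which converted my strings into numeric values, for example: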

  • John-> 1.0

  • Jose-> 2.0

  • Willian-> 3.0
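
To make the mapping concrete without a Spark cluster, here is a minimal plain-Java sketch (not Spark's implementation) of what StringIndexer learns with its default "frequencyDesc" ordering: the most frequent string gets index 0.0, the next one 1.0, and so on, so real indices start at 0.0. The class and method names (StringIndexerSketch, fitLabels, transform) are hypothetical, and breaking ties alphabetically is an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Plain-Java sketch (no Spark) of the label -> index mapping that StringIndexer
// learns with its default "frequencyDesc" ordering. Names are hypothetical and
// the alphabetical tie-break is an assumption of this sketch.
public class StringIndexerSketch {

    // Returns the labels array: position i holds the string encoded as index i.
    public static String[] fitLabels(List<String> column) {
        Map<String, Integer> counts = new HashMap<>();
        for (String value : column) {
            counts.merge(value, 1, Integer::sum);
        }
        List<String> labels = new ArrayList<>(counts.keySet());
        // Most frequent first; equal counts ordered alphabetically.
        labels.sort((a, b) -> {
            int byCount = Integer.compare(counts.get(b), counts.get(a));
            return byCount != 0 ? byCount : a.compareTo(b);
        });
        return labels.toArray(new String[0]);
    }

    // Encodes every value as the (double) position of that value in labels.
    public static double[] transform(List<String> column, String[] labels) {
        Map<String, Double> indexOf = new HashMap<>();
        for (int i = 0; i < labels.length; i++) {
            indexOf.put(labels[i], (double) i);
        }
        double[] encoded = new double[column.size()];
        for (int i = 0; i < encoded.length; i++) {
            Double index = indexOf.get(column.get(i));
            if (index == null) {
                throw new IllegalArgumentException("Unseen label: " + column.get(i));
            }
            encoded[i] = index;
        }
        return encoded;
    }

    public static void main(String[] args) {
        List<String> ids = Arrays.asList("Jose", "John", "Jose", "Willian", "Jose", "John");
        String[] labels = fitLabels(ids);
        System.out.println(Arrays.toString(labels));                 // [Jose, John, Willian]
        System.out.println(Arrays.toString(transform(ids, labels))); // [0.0, 1.0, 0.0, 2.0, 0.0, 1.0]
    }
}
```

The labels array returned by fitLabels is the whole "model": keeping it is what makes the numbers reversible later.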

Is there a way to convert the numeric values back into strings? I want to get the original strings back. I have seen IndexToString, but I can't get it to work.
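
IndexToString performs the reverse lookup: each index is replaced by the string at that position in a labels array. Below is a minimal plain-Java sketch of that lookup (no Spark dependency; the names IndexToStringSketch and decode and the label values are hypothetical), assuming the labels array is the one produced by the StringIndexerModel fitted on the same column.

```java
import java.util.Arrays;

// Plain-Java sketch (no Spark) of the reverse lookup IndexToString performs:
// index i is replaced by labels[i]. Names and label values are hypothetical.
public class IndexToStringSketch {

    public static String[] decode(double[] indices, String[] labels) {
        String[] decoded = new String[indices.length];
        for (int i = 0; i < indices.length; i++) {
            int position = (int) indices[i];
            // This sketch rejects fractional or out-of-range indices.
            if (position != indices[i] || position < 0 || position >= labels.length) {
                throw new IllegalArgumentException("Unseen index: " + indices[i]);
            }
            decoded[i] = labels[position];
        }
        return decoded;
    }

    public static void main(String[] args) {
        // Stand-in for the labels of the StringIndexerModel fitted on IDPatient.
        String[] patientLabels = {"Jose", "John", "Willian"};
        double[] idPatientIndex = {0.0, 1.0, 0.0, 2.0};
        System.out.println(Arrays.toString(decode(idPatientIndex, patientLabels))); // [Jose, John, Jose, Willian]
    }
}
```

In Spark, the labels can either be passed explicitly with setLabels or, when none are set, IndexToString reads them from the ML attribute metadata that StringIndexer attaches to its output column.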

Edit: here is the code:

The Dataset newDF contains all of my data.

    final VectorAssembler vectorAssembler = new VectorAssembler()
            .setInputCols(new String[] { "IDPatient_index", "Temp", "SPO2Min", "SPO2Max", "BPMmin", "BPMmax",
                    "BPMavg", "SYS", "DIA", "EDAmin", "EDAmax", "EDAavg" })
            .setOutputCol("features");


    newDF.createOrReplaceTempView("tdd");

    Dataset<Row> transformedDS = sparkSession.sql(
            "SELECT CAST(IDPatient as String) IDPatient ,CAST(Temp as float) Temp, CAST(SPO2Min as float) SPO2Min, "
                    + "CAST(SPO2Max as float) SPO2Max, CAST(BPMmin as float) BPMmin, CAST(BPMmax as float) BPMmax, CAST(BPMavg as float) BPMavg, "
                    + "CAST(SYS as float) SYS, CAST(DIA as float) DIA, CAST(EDAmin as float) EDAmin, CAST(EDAmax as float) EDAmax, CAST(EDAavg as float) EDAavg "
                    + "FROM tdd");


    // StringIndexer
    StringIndexerModel indexer = new StringIndexer().setInputCol("IDPatient").setOutputCol("IDPatient_index").fit(transformedDS);


    Dataset<Row> indexed = indexer.transform(transformedDS);
    indexed.createOrReplaceTempView("tdd");

    // final dataset
    Dataset<Row> p1 = sparkSession.sql(
            "SELECT IDPatient_index, Temp, SPO2Min, SPO2Max, BPMmin, BPMmax, BPMavg, SYS, DIA, EDAmin, EDAmax, EDAavg FROM tdd");

    final Dataset<Row> featuresData = vectorAssembler.transform(p1);

    final StringIndexerModel labelIndexer = new StringIndexer().setInputCol("Temp").setOutputCol("indexedLabel")
            .fit(featuresData);

    final VectorIndexerModel featureIndexer = new VectorIndexer().setInputCol("features")
            .setOutputCol("indexedFeatures").fit(featuresData);

    Dataset<Row>[] splits = featuresData.randomSplit(new double[] { 0.7, 0.3 });
    Dataset<Row> trainingFeaturesData = splits[0];
    Dataset<Row> testFeaturesData = splits[1];

    // Train a DecisionTree model.
    final DecisionTreeClassifier dt = new DecisionTreeClassifier().setLabelCol("indexedLabel")
            .setFeaturesCol("indexedFeatures");

    final IndexToString labelConverter = new IndexToString().setInputCol("prediction")
            .setOutputCol("predictedOccupancy").setLabels(labelIndexer.labels());

    IndexToString back2string = new IndexToString().setInputCol("IDPatient_index").setOutputCol("IDPatientfin")
        .setLabels(labelIndexer.labels());


    final Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[] { labelIndexer, featureIndexer, dt, labelConverter });


    final PipelineModel model = pipeline.fit(trainingFeaturesData);

    final Dataset<Row> predictions = model.transform(testFeaturesData);

    System.out.println("Checking the conversion back to strings");
    predictions.show();

    System.out.println("Example records with Predicted Temp as 0:");
    predictions.select("predictedOccupancy", "Temp", "features")
            .where(predictions.col("predictedOccupancy").equalTo(36.31)).show(10);

    System.out.println("Example records with Predicted Temp as 1:");
    predictions.select("predictedOccupancy", "Temp", "features")
            .where(predictions.col("predictedOccupancy").equalTo(36.3)).show(10);

    System.out.println("Example records with In-correct predictions:");
    predictions.select("predictedOccupancy", "Temp", "features")
            .where(predictions.col("predictedOccupancy").notEqual(predictions.col("Temp"))).show();
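
One possible reason IndexToString seems not to work in the code above: back2string is built with labelIndexer.labels(), i.e. the labels of the indexer fitted on Temp, not those of indexer, which was fitted on IDPatient. Also, back2string is never added to the pipeline or applied with transform, and p1 no longer contains the original IDPatient column. The plain-Java sketch below (no Spark; MismatchedLabelsSketch and all label values are hypothetical stand-ins) shows what decoding one column's indices with another column's labels produces.

```java
import java.util.Arrays;

// Plain-Java sketch (no Spark) of decoding indices with the labels array of a
// different column. Both label arrays are hypothetical stand-ins.
public class MismatchedLabelsSketch {

    // Same kind of lookup IndexToString does: index i -> labels[i].
    public static String[] decode(double[] indices, String[] labels) {
        String[] decoded = new String[indices.length];
        for (int i = 0; i < indices.length; i++) {
            decoded[i] = labels[(int) indices[i]];
        }
        return decoded;
    }

    public static void main(String[] args) {
        String[] patientLabels = {"Jose", "John", "Willian"};    // like indexer.labels() (IDPatient)
        String[] tempLabels = {"36.3", "36.31", "36.5", "37.0"}; // like labelIndexer.labels() (Temp)
        double[] idPatientIndex = {0.0, 2.0, 1.0};               // values of IDPatient_index

        // Wrong pairing: patient indices decoded with Temp's labels.
        System.out.println(Arrays.toString(decode(idPatientIndex, tempLabels)));    // [36.3, 36.5, 36.31]
        // Matching pairing: patient indices decoded with IDPatient's labels.
        System.out.println(Arrays.toString(decode(idPatientIndex, patientLabels))); // [Jose, Willian, John]
    }
}
```

Under that assumption, the Spark-side change would be .setLabels(indexer.labels()) on back2string, followed by something like back2string.transform(predictions), since IDPatient_index is still a column of predictions.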

Thanks a lot, everyone!

0 answers:

No answers