Spark管道上的Deeplearning:如何预测在管道中使用神经网络模型?

时间:2016-03-08 06:34:22

标签: java apache-spark neural-network deep-learning deeplearning4j

我正在尝试向Spark管道添加情绪分析程序。这样做时,我的课程延伸org.apache.spark.ml.PredictionModel。扩展此PredictionModel类时,我必须覆盖predict()方法,该方法预测给定要素的标签。但是,当我执行此代码时,我总是得到0或1.例如,如果有10个电影评论,5个是负面评论而其他5个是否定评论,则将所有评论归类为否定。我附上了以下代码。

import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.*;

//Model produced by a ProbabilisticClassifier
public class MovieReviewClassifierModel extends PredictionModel<Object, MovieReviewClassifierModel> implements  Serializable{


    private static final long serialVersionUID = 1L;
    private MultiLayerNetwork net;

    MovieReviewClassifierModel (MultiLayerNetwork net) throws Exception {
        this.net=net;
 }

    @Override
    public MovieReviewClassifierModel copy(ParamMap args0) {
        return null;
    }

    @Override
    public String uid() {
        return "MovieReviewClassifierModel";
    }


    public double raw2prediction(Vector rawPrediction) {//Given a vector of raw predictions, select the predicted label
        return rawPrediction.toArray()[0];
    }

    @Override
    public double predict(Object o) {

        int prediction=0;
        DenseVector v=(DenseVector)o;
        double[] a=v.toArray();
        INDArray arr=Nd4j.create(a);
        INDArray array= net.output(arr,false);
        DataBuffer ob = array.data();
        double[] d=ob.asDouble();
        double zeroProbability=d[0];
        double oneProbability=d[1];
        if (zeroProbability > oneProbability) {
            prediction=0;
        }
        else{
            prediction=1;

        }


        return prediction;
    }


}

你能告诉我错误预测的理由吗?

1 个答案:

答案 0 :(得分:0)

public double predict(Object o)中,您有以下if声明:

if (zeroProbability > oneProbability) {
    prediction=0;
}
else{
    prediction=1;

}

导致返回0或1.更改此方法以获得一些其他预测值。