正确执行逻辑回归

时间:2019-04-24 09:15:55

标签: java logistic-regression

我正在构建一个Android应用,该应用基于Score变量预测疲劳。作为预测模型,我使用逻辑回归,更具体地说,我在这里找到的* @author tpeng * @author Matthieu Labas的类:https://github.com/tpeng/logistic-regression/blob/master/src/Logistic.java 问题是我得到的输出概率似乎非常不正确。

我尝试使用包含5个预测变量的代码提供的原始数据集,结果似乎好得多。我也尝试过尝试学习率和迭代,但没有成功。

    /**
     * Performs simple logistic regression.
     * User: tpeng
     * Date: 6/22/12
     * Time: 11:01 PM
     *
     * @author tpeng
     * @author Matthieu Labas
     */
    public class LogisticRegression {
    private static Context mContext;
        /** the learning rate */
        private double rate;

/** the weight to learn */
private double[] weights;

/** the number of iterations */
private int ITERATIONS = 1000;

public LogisticRegression(int n, Context context) {
    this.rate = 0.0001;
    weights = new double[n];
    this.mContext=context;
}

private static double sigmoid(double z) {
    return 1.0 / (1.0 + Math.exp(-z));
}

public void train(List<Instance> instances) {
    for (int n=0; n<ITERATIONS; n++) {
        double lik = 0.0;
        for (int i=0; i<instances.size(); i++) {
            int[] x = instances.get(i).x;
            double predicted = classify(x);
            int label = instances.get(i).label;
            for (int j=0; j<weights.length; j++) {
                weights[j] = weights[j] + rate * (label - predicted) * x[j];
            }
            // not necessary for learning
           // lik += label * Math.log(classify(x)) + (1-label) * Math.log(1- classify(x));
        }
  //      System.out.println("iteration: " + n + " " + Arrays.toString(weights) + " mle: " + lik);
    }
}

public double classify(int[] x) {
    double logit = .0;
    for (int i=0; i<weights.length;i++)  {
        logit += weights[i] * x[i];
    }
    return sigmoid(logit);
}

public static class Instance {
    public int label;
    public int[] x;

    public Instance(int label, int[] x) {
        this.label = label;
        this.x = x;
    }
}

public static List<Instance> readDataSet(String path) {
    List<Instance> dataset = new ArrayList<Instance>();
    Scanner scanner = null;
    AssetManager am = mContext.getAssets();
    try {
        InputStream is = am.open(path);
        scanner = new Scanner(new InputStreamReader(is));
        while(scanner.hasNextLine()) {
            String line = scanner.nextLine();
            if (line.startsWith("#")) {
                continue;
            }
            String[] columns = line.split("\\s+");

            // skip first column and last column is the label
            int i = 1;
            int[] data = new int[columns.length-2];
            for (i=1; i<columns.length-1; i++) {
                data[i-1] = Integer.parseInt(columns[i]);
            }
            int label = Integer.parseInt(columns[i]);
            Instance instance = new Instance(label, data);
            dataset.add(instance);
        }
    }catch (IOException e) {
        e.printStackTrace();
    } finally {
        if (scanner != null)
            scanner.close();
    }
    return dataset;
}


    }

我的主要活动只是调用n = 1的logisticRegression.classify

我正在使用包含分数从0到10的自制数据集,我希望疲劳概率非常低,接近0,接近10时非常高,无论我测试的得分如何,当概率为0.5时接近0时接近0,接近10时接近0.7。我使用外部逻辑回归工具和相同的数据集进行了双重检查,得到的结果好得多。我查看了一下代码,自己没有发现任何错误,但它无法正常工作。这也是我的数据集的一部分,第二列是得分,第三列是标签:

testdata

1 10 1

2 9 1

3 10 1

4 9 1

5 5 0

6 4 0

7 3 0

8 2 0

9 10 1

10 10 1

11 7 1

12 6 1

13 6 0

14 5 0

15 10 1

16 10 1

17 9 1

18 2 0

19 1 0

0 个答案:

没有答案
相关问题