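I'm building an Android app that predicts fatigue from a Score variable. As the prediction model I'm using logistic regression, specifically the class by @author tpeng and @author Matthieu Labas that I found here: https://github.com/tpeng/logistic-regression/blob/master/src/Logistic.java The problem is that the output probabilities I get seem very wrong.
I tried the original dataset that ships with the code, which has 5 predictors, and the results seemed much better. I also experimented with the learning rate and the number of iterations, without success.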
import android.content.Context;
import android.content.res.AssetManager;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;

/**
 * Performs simple logistic regression.
 * User: tpeng
 * Date: 6/22/12
 * Time: 11:01 PM
 *
 * @author tpeng
 * @author Matthieu Labas
 */
public class LogisticRegression {
    private static Context mContext;
    /** the learning rate */
    private double rate;
    /** the weight to learn */
    private double[] weights;
    /** the number of iterations */
    private int ITERATIONS = 1000;

    public LogisticRegression(int n, Context context) {
        this.rate = 0.0001;
        weights = new double[n];
        // mContext is static, so assign it without "this"
        mContext = context;
    }

    private static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    public void train(List<Instance> instances) {
        for (int n = 0; n < ITERATIONS; n++) {
            double lik = 0.0;
            for (int i = 0; i < instances.size(); i++) {
                int[] x = instances.get(i).x;
                double predicted = classify(x);
                int label = instances.get(i).label;
                for (int j = 0; j < weights.length; j++) {
                    weights[j] = weights[j] + rate * (label - predicted) * x[j];
                }
                // not necessary for learning
                // lik += label * Math.log(classify(x)) + (1-label) * Math.log(1- classify(x));
            }
            // System.out.println("iteration: " + n + " " + Arrays.toString(weights) + " mle: " + lik);
        }
    }

    public double classify(int[] x) {
        double logit = .0;
        for (int i = 0; i < weights.length; i++) {
            logit += weights[i] * x[i];
        }
        return sigmoid(logit);
    }

    public static class Instance {
        public int label;
        public int[] x;

        public Instance(int label, int[] x) {
            this.label = label;
            this.x = x;
        }
    }

    public static List<Instance> readDataSet(String path) {
        List<Instance> dataset = new ArrayList<Instance>();
        Scanner scanner = null;
        AssetManager am = mContext.getAssets();
        try {
            InputStream is = am.open(path);
            scanner = new Scanner(new InputStreamReader(is));
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (line.startsWith("#")) {
                    continue;
                }
                String[] columns = line.split("\\s+");
                // skip first column and last column is the label
                int i = 1;
                int[] data = new int[columns.length - 2];
                for (i = 1; i < columns.length - 1; i++) {
                    data[i - 1] = Integer.parseInt(columns[i]);
                }
                int label = Integer.parseInt(columns[i]);
                Instance instance = new Instance(label, data);
                dataset.add(instance);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (scanner != null)
                scanner.close();
        }
        return dataset;
    }
}
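My main activity simply calls logisticRegression.classify, with the model created using n = 1 (a single predictor).
I'm using a self-made dataset with scores ranging from 0 to 10. I expect the fatigue probability to be very low for scores near 0 and very high for scores near 10. Instead, whatever score I test, the predicted probability only ranges from about 0.5 for scores near 0 to about 0.7 for scores near 10. I double-checked with an external logistic regression tool on the same dataset and got much better results. I have gone through the code and could not find any mistake myself, but it does not work properly. Here is part of my dataset; the second column is the score and the third column is the label: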
1 10 1
2 9 1
3 10 1
4 9 1
5 5 0
6 4 0
7 3 0
8 2 0
9 10 1
10 10 1
11 7 1
12 6 1
13 6 0
14 5 0
15 10 1
16 10 1
17 9 1
18 2 0
19 1 0
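
To make the symptom reproducible without Android, here is a minimal sketch that applies the same update rule to the 19 rows above, with the score as the only feature. The class name `NoInterceptDemo` and its helper methods are hypothetical names introduced for this reduction and are not part of the original class. One thing it makes visible: because `classify` computes `sigmoid(weight * score)` with no constant term, a score of 0 always yields `sigmoid(0) = 0.5`, whatever weight is learned, which matches the 0.5 I see near 0.

```java
/** Hypothetical stand-alone reduction of the class above: one feature, no intercept. */
class NoInterceptDemo {
    // Score and label columns copied from the 19 rows shown above.
    static final int[] SCORES = {10, 9, 10, 9, 5, 4, 3, 2, 10, 10, 7, 6, 6, 5, 10, 10, 9, 2, 1};
    static final int[] LABELS = {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0};

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Same model as classify(int[] x) with a single weight and no bias term.
    static double classify(double weight, int score) {
        return sigmoid(weight * score);
    }

    // Same per-sample update as train(): w += rate * (label - predicted) * x.
    static double train(int[] scores, int[] labels, double rate, int iterations) {
        double weight = 0.0;
        for (int n = 0; n < iterations; n++) {
            for (int i = 0; i < scores.length; i++) {
                double predicted = classify(weight, scores[i]);
                weight += rate * (labels[i] - predicted) * scores[i];
            }
        }
        return weight;
    }

    public static void main(String[] args) {
        // Same rate and iteration count as in the class above.
        double weight = train(SCORES, LABELS, 0.0001, 1000);
        System.out.println("learned weight: " + weight);
        for (int score = 0; score <= 10; score += 5) {
            System.out.println("score " + score + " -> " + classify(weight, score));
        }
    }
}
```

With a positive weight and non-negative scores, the output of this sketch can never drop below 0.5, so it cannot reproduce the near-0 probabilities I expect for low scores.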