我有以下Predictor
类,它实现了predict()
方法:
class Predictor {
public static int predict(double[] atts) {
if (atts.length != 3) {
return -1;
}
int i, j;
;
for (i = 0; i < 2; i++) {
double sum = 0.;
for (j = 0; j < 3; j++) {
sum += Math.log(2. * Math.PI * sigmas[i][j]);
}
double nij = -0.5 * sum;
sum = 0.;
for (j = 0; j < 3; j++) {
sum += Math.pow(atts[j] - thetas[i][j], 2.) / sigmas[i][j];
}
nij -= 0.5 * sum;
likelihoods[i] = Math.log(priors[i]) + nij;
}
double highestLikeli = Double.NEGATIVE_INFINITY;
int classIndex = -1;
for (i = 0; i < 2; i++) {
if (likelihoods[i] > highestLikeli) {
highestLikeli = likelihoods[i];
classIndex = i;
}
}
return classIndex;
}
public static void main(String[] args) {
if (args.length == 3) {
double[] atts = new double[args.length];
for (int i = 0, l = args.length; i < l; i++) {
atts[i] = Double.parseDouble(args[i]);
}
System.out.println(Predictor.predict(atts));
}
}
}
预测方法预测数字标签(即int号0
或1
)。我在Android应用程序中使用的,以便为浮点数组predict
提供这种浮点数组的数字标签。在视觉上,这看起来如下:
arrayOfFloats -> predict_method -> Label(0/1)
请注意,arrayOfFloats是一个数据流(连续我会有几个这样的数组)。
以下是代码:
public void run() { //stuff that updates ui
finalValues = String.format("%s, %s, %s, %s, %s, %s\n", rollValue, pitchValue, yawValue, gxValue, gyValue, gzValue);
//finalValues = String.format("%s, %s, %s\n", rollValue, pitchValue, yawValue);
Log.e("WalkingActivity", finalValues);
WalkingLog.setText(finalValues);
// Classifier
values[0] = rollValue;
values[1] = pitchValue;
values[2] = yawValue;
//values[3] = gxValue;
//values[4] = gyValue;
//values[5] = gzValue;
int y_pred = Brain.predict(values);
ClassifierLog.setText(Integer.toString(y_pred));
System.out.println("pred: " + Integer.toString(y_pred));
int counter = 0;
//Simple threshold
if (rollValue < -40 && rollValue > -80
&& pitchValue < 0 && yawValue > 0
&& gxValue < 1000 && gxValue > -2000
&& gyValue < 500 && gyValue > -1000
&& gzValue < 1000 && gzValue > -2000) {
ClassifierLog.setText("N");
System.out.println("COUNTER: " + counter);
} else {
counter++;
ClassifierLog.setText("W");
System.out.println("COUNTER: " + counter);
}
}
上述代码的问题在于,我总是得到相同的数字标签1
(我预测相同的数字),而不是获得不同的输出0
或1
。据我所知,我需要将所有浮点值附加到单个数组中。虽然我尝试将这样的数组传递给预测方法,但它无法正常工作。但是它不起作用,这是使用predict
类中的Predictor
方法的正确方法吗?
答案 0 :(得分:3)
获得解决方案的最佳方法是构建&amp;提供一个非常小的java类,它运行并且只包含您希望解决的问题的基本要素(包括输入)。虽然这需要您花费少量精力才能做到这一点,但它会清楚地表达您想要完成的任务。 请参阅下面的示例。
我最好的猜测是你只想传递一个包含三个值的数组并返回一个结果。如果是这种情况,那么您的第一个代码示例显示您已经正确地执行了此操作。
因此,您始终返回1的问题在于Predictor代码本身的逻辑;即sigma / theta逻辑(你没有传达它们的值)。但是,知道正确传入/传出值意味着您只需要调试sigma / theta逻辑。例如,在highestLikeli循环中的if语句之后,添加一个sysout:
System.out.format("%d %f %f\n",classIndex,highestLikeli,likelihoods[i]);
作为参考,这是一个小的java文件,显示传入的数组值和传递的预测值(这是你的问题似乎是关于)。
public class Test {
public static void main(String[] args) {
String params[][] = { { "3.3", "4.4", "5.5" }, { "5.5", "3.3", "4.4" }, { "4.4", "5.5", "3.3" } };
for (int i = 0; i < params[0].length; i++) {
System.out.format("%d: %s, %s, %s => ", i, params[i][0], params[i][1], params[i][2]);
Predictor.main(params[i]);
}
}
static class Predictor {
// Simplistic example taking 3 values and returning the index of the largest.
public static int predict(double[] atts) {
double highestLikeli = Double.NEGATIVE_INFINITY;
int classIndex = -1;
for (int i = 0; i < atts.length; i++) {
if (atts[i] > highestLikeli) {
highestLikeli = atts[i];
classIndex = i;
}
}
return classIndex;
}
public static void main(String[] args) {
if (args.length == 3) {
double[] atts = new double[args.length];
for (int i = 0; i<args.length; i++) {
atts[i] = Double.parseDouble(args[i]);
}
System.out.println(Predictor.predict(atts));
}
}
}
}
结果
0: 3.3, 4.4, 5.5 => 2 1: 5.5, 3.3, 4.4 => 0 2: 4.4, 5.5, 3.3 => 1