将方法应用于浮点数组流时出现问题?

时间:2017-11-07 19:36:16

标签: java android arrays floating-point

我有以下Predictor类,它实现了predict()方法:

class Predictor {

    public static int predict(double[] atts) {
        if (atts.length != 3) {
            return -1;
        }
        int i, j;

        ;

        for (i = 0; i < 2; i++) {
            double sum = 0.;
            for (j = 0; j < 3; j++) {
                sum += Math.log(2. * Math.PI * sigmas[i][j]);
            }
            double nij = -0.5 * sum;
            sum = 0.;
            for (j = 0; j < 3; j++) {
                sum += Math.pow(atts[j] - thetas[i][j], 2.) / sigmas[i][j];
            }
            nij -= 0.5 * sum;
            likelihoods[i] = Math.log(priors[i]) + nij;
        }

        double highestLikeli = Double.NEGATIVE_INFINITY;
        int classIndex = -1;
        for (i = 0; i < 2; i++) {
            if (likelihoods[i] > highestLikeli) {
                highestLikeli = likelihoods[i];
                classIndex = i;
            }
        }
        return classIndex;
    }

    public static void main(String[] args) {
        if (args.length == 3) {
            double[] atts = new double[args.length];
            for (int i = 0, l = args.length; i < l; i++) {
                atts[i] = Double.parseDouble(args[i]);
            }
            System.out.println(Predictor.predict(atts));
        }
    }
}

预测方法预测数字标签(即int号01)。我在Android应用程序中使用的,以便为浮点数组predict提供这种浮点数组的数字标签。在视觉上,这看起来如下:

arrayOfFloats -> predict_method -> Label(0/1)

请注意,arrayOfFloats是一个数据流(连续我会有几个这样的数组)。

以下是代码:

public void run() { //stuff that updates ui
    finalValues = String.format("%s, %s, %s, %s, %s, %s\n", rollValue, pitchValue, yawValue, gxValue, gyValue, gzValue);
    //finalValues = String.format("%s, %s, %s\n",  rollValue, pitchValue, yawValue);

    Log.e("WalkingActivity", finalValues);

    WalkingLog.setText(finalValues);

    // Classifier
    values[0] = rollValue;
    values[1] = pitchValue;
    values[2] = yawValue;
    //values[3] = gxValue;
    //values[4] = gyValue;
    //values[5] = gzValue;

    int y_pred = Brain.predict(values);

    ClassifierLog.setText(Integer.toString(y_pred));

    System.out.println("pred: " + Integer.toString(y_pred));

    int counter = 0;

    //Simple threshold
    if (rollValue < -40 && rollValue > -80
      && pitchValue < 0 && yawValue > 0
      && gxValue < 1000 && gxValue > -2000
      && gyValue < 500 && gyValue > -1000
      && gzValue < 1000 && gzValue > -2000) {
        ClassifierLog.setText("N");                                                 
        System.out.println("COUNTER: " + counter);

    } else {
        counter++;                                         
        ClassifierLog.setText("W");                                               
        System.out.println("COUNTER: " + counter);
    }
}

上述代码的问题在于,我总是得到相同的数字标签1(我预测相同的数字),而不是获得不同的输出01。据我所知,我需要将所有浮点值附加到单个数组中。虽然我尝试将这样的数组传递给预测方法,但它无法正常工作。但是它不起作用,这是使用predict类中的Predictor方法的正确方法吗?

1 个答案:

答案 0 :(得分:3)

获得解决方案的最佳方法是构建&amp;提供一个非常小的java类,它运行并且只包含您希望解决的问题的基本要素(包括输入)。虽然这需要您花费少量精力才能做到这一点,但它会清楚地表达您想要完成的任务。 请参阅下面的示例。

我最好的猜测是你只想传递一个包含三个值的数组并返回一个结果。如果是这种情况,那么您的第一个代码示例显示您已经正确地执行了此操作。

因此,您始终返回1的问题在于Predictor代码本身的逻辑;即sigma / theta逻辑(你没有传达它们的值)。但是,知道正确传入/传出值意味着您只需要调试sigma / theta逻辑。例如,在highestLikeli循环中的if语句之后,添加一个sysout:


    System.out.format("%d %f %f\n",classIndex,highestLikeli,likelihoods[i]);

作为参考,这是一个小的java文件,显示传入的数组值和传递的预测值(这是你的问题似乎是关于)。



    public class Test {

        public static void main(String[] args) {
            String params[][] = { { "3.3", "4.4", "5.5" }, { "5.5", "3.3", "4.4" }, { "4.4", "5.5", "3.3" } };
            for (int i = 0; i < params[0].length; i++) {
                System.out.format("%d: %s, %s, %s => ", i, params[i][0], params[i][1], params[i][2]);
                Predictor.main(params[i]);
            }
        }

        static class Predictor {

            // Simplistic example taking 3 values and returning the index of the largest.
            public static int predict(double[] atts) {
                double highestLikeli = Double.NEGATIVE_INFINITY;
                int classIndex = -1;
                for (int i = 0; i < atts.length; i++) {
                    if (atts[i] > highestLikeli) {
                        highestLikeli = atts[i];
                        classIndex = i;
                    }
                }
                return classIndex;
            }

            public static void main(String[] args) {
                if (args.length == 3) {
                    double[] atts = new double[args.length];
                    for (int i = 0; i<args.length; i++) {
                        atts[i] = Double.parseDouble(args[i]);
                    }
                    System.out.println(Predictor.predict(atts));
                }
            }
        }
    }

结果


    0: 3.3, 4.4, 5.5 => 2
    1: 5.5, 3.3, 4.4 => 0
    2: 4.4, 5.5, 3.3 => 1