
时间:2016-06-18 22:09:03

标签: java logistic-regression gradient-descent




/**
 * Batch gradient descent for a logistic-regression classifier with two
 * features and a bias term.
 *
 * <p>The model is {@code h(x) = sigmoid(theta0 + theta1 * x1 + theta2 * x2)}.
 * {@link #main(String[])} trains it on a small hard-coded data set, prints the
 * cross-entropy cost for the early iterations, and finally prints the learned
 * coefficients.
 */
public class GradDescent {

    /** Number of gradient-descent updates performed by {@link #main(String[])}. */
    private static final int ITERATIONS = 1819800;

    /** Learning rate (kept as a double so the value is exactly the literal written here). */
    private static final double ALPHA = 0.009;

    /** The cost is printed only while the 1-based iteration counter is below this value. */
    private static final int COST_LOG_LIMIT = 1000;

    /**
     * Trains the classifier on the built-in data set and prints the result.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double[] x1 = {0.0, 0.5, 1.0, 1.5, 2.0, 0.1, 3.0, 3.1, 3.5, 3.2, 2.5, 2.8};
        double[] x2 = {0.0, 1.0, 1.1, 0.5, 0.3, 2.0, 3.0, 0.3, 1.5, 2.2, 3.6, 2.8};
        double[] y = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};

        // theta[0] is the bias, theta[1] multiplies x1, theta[2] multiplies x2.
        double[] theta = {0.0, 0.0, 0.0};

        for (int counter = 1; counter <= ITERATIONS; counter++) {
            step(theta, x1, x2, y, ALPHA);

            // Report the cost of the freshly updated parameters, early iterations only.
            if (counter < COST_LOG_LIMIT) {
                System.out.println("Cost at " + counter + " is " + cost(theta, x1, x2, y));
            }
        }

        System.out.println(theta[2] + "x2 + " + theta[1] + "x1 + " + theta[0]);

        testGradientDescent(2f, theta[0], theta[1], theta[2]);
    }

    /**
     * Performs one simultaneous batch gradient-descent update of {@code theta} in place.
     *
     * <p>Package-private (rather than private) so it can be exercised directly by tests.
     *
     * @param theta three-element parameter vector {bias, weight of x1, weight of x2}; updated in place
     * @param x1    first feature of every example
     * @param x2    second feature of every example
     * @param y     label (0 or 1) of every example
     * @param alpha learning rate
     * @throws IllegalArgumentException if {@code theta} does not have three elements, or the
     *                                  example arrays are empty or differ in length
     */
    static void step(double[] theta, double[] x1, double[] x2, double[] y, double alpha) {
        checkInputs(theta, x1, x2, y);
        int m = y.length;

        double grad0 = 0.0;
        double grad1 = 0.0;
        double grad2 = 0.0;
        for (int j = 0; j < m; j++) {
            // The prediction error is shared by all three partial derivatives, so compute it once.
            double error = sigmoid(theta[0] + x1[j] * theta[1] + x2[j] * theta[2]) - y[j];
            grad0 += error;
            grad1 += error * x1[j];
            grad2 += error * x2[j];
        }

        // All gradients were taken at the old theta, so these assignments form a simultaneous update.
        theta[0] -= (alpha * grad0) / m;
        theta[1] -= (alpha * grad1) / m;
        theta[2] -= (alpha * grad2) / m;
    }

    /**
     * Computes the mean cross-entropy cost
     * {@code -(1/m) * sum(y * log(h) + (1 - y) * log(1 - h))} of the model on the given examples.
     *
     * <p>Package-private (rather than private) so it can be exercised directly by tests.
     *
     * @param theta three-element parameter vector {bias, weight of x1, weight of x2}
     * @param x1    first feature of every example
     * @param x2    second feature of every example
     * @param y     label (0 or 1) of every example
     * @return the mean cross-entropy cost
     * @throws IllegalArgumentException if {@code theta} does not have three elements, or the
     *                                  example arrays are empty or differ in length
     */
    static double cost(double[] theta, double[] x1, double[] x2, double[] y) {
        checkInputs(theta, x1, x2, y);
        int m = y.length;

        // Uses its own accumulator; the gradient sums must not leak into the cost.
        double logLikelihood = 0.0;
        for (int j = 0; j < m; j++) {
            double h = sigmoid(theta[0] + x1[j] * theta[1] + x2[j] * theta[2]);
            logLikelihood += y[j] * Math.log(h) + (1 - y[j]) * Math.log(1 - h);
        }
        return (logLikelihood / m) * -1;
    }

    /** Rejects parameter/example arrays that the training helpers cannot work with. */
    private static void checkInputs(double[] theta, double[] x1, double[] x2, double[] y) {
        if (theta.length != 3) {
            throw new IllegalArgumentException("theta must have 3 elements but has " + theta.length);
        }
        if (y.length == 0) {
            throw new IllegalArgumentException("at least one training example is required");
        }
        if (x1.length != y.length || x2.length != y.length) {
            throw new IllegalArgumentException(
                    "x1, x2 and y must have the same length: "
                            + x1.length + ", " + x2.length + ", " + y.length);
        }
    }

    /**
     * Logistic function {@code 1 / (1 + e^-x)}.
     *
     * @param x any real number
     * @return a value between 0 and 1
     */
    private static double sigmoid(double x) {
        return 1 / (1 + Math.exp(-1 * x));
    }

    /**
     * Prints the learned linear function as an expression string.
     *
     * <p>The plotting calls that once consumed this string (JavaPlot) were commented out in the
     * original code, so nothing is actually plotted; only the "Plotting ..." line is printed.
     *
     * @param n      currently unused; kept so existing callers keep compiling
     * @param theta0 bias term
     * @param theta1 coefficient of x1
     * @param theta2 coefficient of x2
     */
    private static void testGradientDescent(float n, double theta0, double theta1, double theta2) {
        String outputFunction = theta2 + "*x2+" + theta1 + "*x1+" + theta0;

        System.out.println("Plotting " + outputFunction);
    }
}



0 个答案:
