BOBYQAOptimizer返回0权重-Java

时间:2018-09-30 09:40:41

标签: java apache optimization

我正在尝试使用org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer进行逻辑回归,但每个回归变量的权重为0.0。

我的数据集是4135行,如果不是男性则定位为0,如果是男性则定位为1;回归者的年龄,教育程度和拦截力(均为双打)

CSVRecord是一个案例类,用于读取每一行并将它们传递给yArray(目标)和xArray(功能)

优化器所需的MultivariateFunction实例作为封闭类的属性包括在内,并用在运行优化器之前调用的setter实例化。

其他一些家政记录: makeExponentialInnerProduct()是一种用于计算每一行的逻辑分布累积分布函数的辅助方法-此方法已在另一个程序中进行了检查,并且工作正常;
countLinesNew()是另一种帮助方法,用于计算从here提取的csv文件中的行数

下面是完整程序,调试器指示值已正确读取到数组和类字段中,但结果PointValuePair对象正在打印{0.0,0.0,0.0}。谢谢

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVRecord;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;

import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;

import static net.shanlodh.home.optim.FileLineCounter.countLinesNew;
import static net.shanlodh.home.optim.FileLineCounter.makeExponentialInnerProduct;

public class InnerClassOptimizer {

    public static void main(String[] args) throws IOException {
        int numLines = countLinesNew("/data.csv", true);

        double[] yArray = new double[numLines];
        double[][] xArray = new double[numLines][3];

        Reader in = new FileReader("/data.csv");

        Iterable<CSVRecord> records = CSVFormat.EXCEL.withHeader().parse(in);
        int lineNum = 0;
        for (CSVRecord record : records) {
            yArray[lineNum] = Double.parseDouble(record.get("male"));
            xArray[lineNum][0] = 1;
            xArray[lineNum][1] = Double.parseDouble(record.get("age"));
            xArray[lineNum][2] = Double.parseDouble(record.get("education"));
            lineNum++;
        }

        InnerClassOptimizer ico = new InnerClassOptimizer(yArray, xArray);

        ico.setFunc(yArray, xArray, ico.getInitialWeights());

        MultivariateFunction func = ico.getFunc();

        BOBYQAOptimizer optim = new BOBYQAOptimizer(10);
        PointValuePair result = optim.optimize(new MaxEval(1000),
                new ObjectiveFunction(func),
                GoalType.MINIMIZE,
                new InitialGuess(ico.getInitialWeights()),
                SimpleBounds.unbounded(3));
        for (int i = 0; i < result.getPoint().length; i++){
            System.out.printf("%f ", result.getPoint()[i]);
            System.out.printf("%f ", result.getPointRef()[i]);
            System.out.printf("%f ", result.getFirst()[i]);
            System.out.printf("%f ", result.getKey()[i]);
        }//above all prints 0.0
    }

    private double[] yArray;
    private double[][] xArray;
    private double[] initialWeights = new double[]{0.0, 0.0, 0.0};
    private MultivariateFunction func;

    public InnerClassOptimizer(double[] yArray, double[][] xArray) {
        this.yArray = yArray;
        this.xArray = xArray;
    }

    public MultivariateFunction getFunc() {
        return func;
    }

    public double[] getInitialWeights() {
        return initialWeights;
    }


    public void setFunc(double[] yArray, double[][] xArray, double[] initialWeights) {
        this.func = new MultivariateFunction() {
            @Override
            public double value(double[] doubles) {
                doubles = new double[yArray.length];
                for (int row = 0; row < doubles.length; row++) {
                    doubles[row] = yArray[row];
                }
                double logLikelihood = 0;
                for (int row = 0; row < doubles.length; row++) {
                    logLikelihood += (2 * doubles[row] - 1) * makeExponentialInnerProduct(
                            xArray[row], initialWeights);
                }
                return logLikelihood;
        }
        };
    }
}

0 个答案:

没有答案