我的线性下降的梯度下降实现如下所示:
private static void byGradientDescent(double[][] dataX, double[][] dataY){
double alpha = 0.05;
int m = dataX[0].length ;//variable number;
int n = dataX.length; //sample number
System.out.println(n+"\t"+m);
double[] thetas = new double[m + 1]; //thetas[m] is the intercept
for(int i = 0;i<thetas.length;i++) //initialize
thetas[i] = 0.5;
double[] derivatives = new double[m + 1];
boolean flag = false;
double lastRSS = 0;
do {
for(int i = 0;i<derivatives.length;i++)
derivatives[i] = 0;
double RSS = 0;
for (int i = 0; i < n; i++) { //calculate derivatives
double diff = thetas[m] ; //difference
for(int j = 0;j<m;j++)
diff += thetas[j] * dataX[i][j];
diff = diff - dataY[i][0];
RSS += diff*diff;
derivatives[m] += diff / n;
for(int j = 0;j<m;j++){
derivatives[j] = diff * dataX[i][j] /n;
}
}
for(int i = 0;i<thetas.length;i++) // update thetas
thetas[i] = thetas[i] - (alpha * derivatives[i]);
System.out.println(lastRSS - RSS);
lastRSS = RSS;
System.out.println(Arrays.toString(thetas));
flag = false;
for(double derivative : derivatives) // termination condition
flag = flag || (Math.abs(derivative)>0.01);
} while (flag);
}
但是,它仅适用于非常小的示例,但在Weka的cpu.arff数据集上失败。我为alpha尝试了不同的值,但仍然无法正常工作。例如,当alpha设置为0.00000005时,循环不会停止,而当其设置为0.0005时,将返回[Infinity,Infinity,NaN,NaN,NaN,NaN,Infinity,Infinity]。我不确定自己的实现是否有问题或使用方式有误。
下面是我用来求解数据集的代码。
public class LinearRegressionByLS {
public static void main(String arg[]) throws Exception {
double[][] X= {{5,0,10}, {1,1,9}, {1,2,10}, {1,3,12}};
double[][] Y = {{1},{2},{3},{5}};
byGradientDescent(X, Y);
DataSource source = new DataSource("D:\\Program Files\\Weka-3-8-4\\data\\cpu.arff");
Instances instances = source.getDataSet();
instances.setClassIndex(instances.numAttributes() - 1);
int n = instances.numInstances();
int m = instances.numAttributes() - 1;
double[][] dataX = new double[n][m];
double[][] dataY = new double[n][1];
for (int i = 0; i < n; i++) {
double[] values = instances.instance(i).toDoubleArray();
for (int j = 0; j < m; j++)
dataX[i][j] = values[j];
dataY[i][0] = values[m];
}
byGradientDescent(dataX, dataY);
}
}