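I am trying to build a simple model that classifies points into 2 partitions of a 2D space:
Unfortunately, I am not getting the answer I expected. Am I missing something in the code, or am I doing something wrong?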
public class SimpleClassifier {
public static class Point{
public int x;
public int y;
public Point(int x,int y){
this.x = x;
this.y = y;
}
@Override
public boolean equals(Object arg0) {
Point p = (Point) arg0;
return( (this.x == p.x) &&(this.y== p.y));
}
@Override
public String toString() {
// TODO Auto-generated method stub
return this.x + " , " + this.y ;
}
}
public static void main(String[] args) {
Map<Point,Integer> points = new HashMap<SimpleClassifier.Point, Integer>();
points.put(new Point(0,0), 0);
points.put(new Point(1,1), 0);
points.put(new Point(1,0), 0);
points.put(new Point(0,1), 0);
points.put(new Point(2,2), 0);
points.put(new Point(8,8), 1);
points.put(new Point(8,9), 1);
points.put(new Point(9,8), 1);
points.put(new Point(9,9), 1);
OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression();
learningAlgo = new OnlineLogisticRegression(2, 2, new L1());
learningAlgo.learningRate(50);
//learningAlgo.alpha(1).stepOffset(1000);
System.out.println("training model \n" );
for(Point point : points.keySet()){
Vector v = getVector(point);
System.out.println(point + " belongs to " + points.get(point));
learningAlgo.train(points.get(point), v);
}
learningAlgo.close();
//now classify real data
Vector v = new RandomAccessSparseVector(2);
v.set(0, 0.5);
v.set(1, 0.5);
Vector r = learningAlgo.classifyFull(v);
System.out.println(r);
System.out.println("ans = " );
System.out.println("no of categories = " + learningAlgo.numCategories());
System.out.println("no of features = " + learningAlgo.numFeatures());
System.out.println("Probability of cluster 0 = " + r.get(0));
System.out.println("Probability of cluster 1 = " + r.get(1));
}
public static Vector getVector(Point point){
Vector v = new DenseVector(2);
v.set(0, point.x);
v.set(1, point.y);
return v;
}
}
Output:
ans =
no of categories = 2
no of features = 2
Probability of cluster 0 = 3.9580985042775296E-4
Probability of cluster 1 = 0.9996041901495722
99% of the time the output shows a higher probability for cluster 1. Why?
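Answer 0 (score: 5)
The problem is that you did not include the bias (intercept) term, which is always 1. You need to add the bias term (1) to your point class.
This is a very basic mistake that even many experienced people make in machine learning. It might be a good idea to invest some time in learning the theory. Andrew Ng's lectures are a good place to start.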
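To see why the missing intercept matters, here is a minimal plain-Java sketch that does not use Mahout at all. It trains a hand-written logistic regression with stochastic gradient descent on the nine points from the question, once with features (x, y) and once with (x, y, 1). The class name BiasDemo, the learning rate of 0.01 and the 5000 epochs are my own assumptions for illustration, not Mahout defaults.

```java
/** Plain-Java sketch (no Mahout): logistic regression by SGD, with and without a bias feature. */
class BiasDemo {
    // The nine training points from the question, as {x, y, label}.
    static final double[][] DATA = {
        {0, 0, 0}, {1, 1, 0}, {1, 0, 0}, {0, 1, 0}, {2, 2, 0},
        {8, 8, 1}, {8, 9, 1}, {9, 8, 1}, {9, 9, 1}
    };

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Builds the feature vector; when useBias is true a constant 1 is appended. */
    static double[] features(double x, double y, boolean useBias) {
        return useBias ? new double[] {x, y, 1.0} : new double[] {x, y};
    }

    static double dot(double[] w, double[] f) {
        double sum = 0.0;
        for (int i = 0; i < w.length; i++) {
            sum += w[i] * f[i];
        }
        return sum;
    }

    /** Trains the weights with plain SGD on the log-loss (assumed hyperparameters). */
    static double[] train(boolean useBias, int epochs, double rate) {
        double[] w = new double[useBias ? 3 : 2];
        for (int e = 0; e < epochs; e++) {
            for (double[] row : DATA) {
                double[] f = features(row[0], row[1], useBias);
                double error = row[2] - sigmoid(dot(w, f));
                for (int i = 0; i < w.length; i++) {
                    w[i] += rate * error * f[i];
                }
            }
        }
        return w;
    }

    /** Probability of class 1; a weight vector of length 3 means a bias feature was used. */
    static double probClass1(double[] w, double x, double y) {
        return sigmoid(dot(w, features(x, y, w.length == 3)));
    }

    public static void main(String[] args) {
        double[] noBias = train(false, 5000, 0.01);
        double[] withBias = train(true, 5000, 0.01);
        System.out.println("no bias,   P(class 1 | 0, 0)     = " + probClass1(noBias, 0, 0));
        System.out.println("no bias,   P(class 1 | 0.5, 0.5) = " + probClass1(noBias, 0.5, 0.5));
        System.out.println("with bias, P(class 1 | 0.5, 0.5) = " + probClass1(withBias, 0.5, 0.5));
    }
}
```

Without the constant feature the score at the origin is always 0, so the model can never say anything other than 0.5 there, and its decision boundary is forced through (0, 0). In this sketch that pushes the point (0.5, 0.5) toward class 1, which is the same symptom as in the question. With the constant feature the boundary can move away from the origin.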
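To make the code give the expected output, the following needs to change: the feature vector gets a third component fixed at 1 (so the model is created with 3 features instead of 2), and the learning rate is lowered from 50 to 10. The example also sets lambda to 0.1.
You will now get P(0) = 0.9999 for class 0.
Here is a complete working example that gives the correct result: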
import java.util.HashMap;
import java.util.Map;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
class Point{
public int x;
public int y;
public Point(int x,int y){
this.x = x;
this.y = y;
}
@Override
public boolean equals(Object arg0) {
Point p = (Point) arg0;
return( (this.x == p.x) &&(this.y== p.y));
}
@Override
public String toString() {
return this.x + " , " + this.y ;
}
}
public class SimpleClassifier {
public static void main(String[] args) {
Map<Point,Integer> points = new HashMap<Point, Integer>();
points.put(new Point(0,0), 0);
points.put(new Point(1,1), 0);
points.put(new Point(1,0), 0);
points.put(new Point(0,1), 0);
points.put(new Point(2,2), 0);
points.put(new Point(8,8), 1);
points.put(new Point(8,9), 1);
points.put(new Point(9,8), 1);
points.put(new Point(9,9), 1);
OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression();
learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
learningAlgo.lambda(0.1);
learningAlgo.learningRate(10);
System.out.println("training model \n" );
for(Point point : points.keySet()){
Vector v = getVector(point);
System.out.println(point + " belongs to " + points.get(point));
learningAlgo.train(points.get(point), v);
}
learningAlgo.close();
Vector v = new RandomAccessSparseVector(3);
v.set(0, 0.5);
v.set(1, 0.5);
v.set(2, 1);
Vector r = learningAlgo.classifyFull(v);
System.out.println(r);
System.out.println("ans = " );
System.out.println("no of categories = " + learningAlgo.numCategories());
System.out.println("no of features = " + learningAlgo.numFeatures());
System.out.println("Probability of cluster 0 = " + r.get(0));
System.out.println("Probability of cluster 1 = " + r.get(1));
}
public static Vector getVector(Point point){
Vector v = new DenseVector(3);
v.set(0, point.x);
v.set(1, point.y);
v.set(2, 1);
return v;
}
}
Output:
2 , 2 belongs to 0
1 , 0 belongs to 0
9 , 8 belongs to 1
8 , 8 belongs to 1
0 , 1 belongs to 0
0 , 0 belongs to 0
1 , 1 belongs to 0
9 , 9 belongs to 1
8 , 9 belongs to 1
{0:2.470723149516907E-6,1:0.9999975292768505}
ans =
no of categories = 2
no of features = 3
Probability of cluster 0 = 2.470723149516907E-6
Probability of cluster 1 = 0.9999975292768505
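Note that I defined the Point class outside the SimpleClassifier class, but that is only to make the code more readable and is not required.
See what happens when you change the learning rate. Read up on cross-validation to learn how to choose the learning rate.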
Learning Rate => Probability of cluster 0
0.001 => 0.4991116089
0.01 => 0.492481585
0.1 => 0.469961472
1 => 0.5327745322
10 => 0.9745740393
100 => 0
1000 => 0
Choosing the learning rate:
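As a rough illustration of what cross-validation over learning rates can look like, here is a leave-one-out sketch in plain Java. It uses a hand-written SGD logistic regression (with the bias feature) as a stand-in for Mahout's OnlineLogisticRegression, so the class name LearningRateSweep, the candidate rates and the 100 epochs are all assumptions of mine, and its numbers will not match the table above.

```java
/** Plain-Java sketch (no Mahout): pick a learning rate by leave-one-out cross-validation. */
class LearningRateSweep {
    // The nine training points from the question, as {x, y, label}.
    static final double[][] DATA = {
        {0, 0, 0}, {1, 1, 0}, {1, 0, 0}, {0, 1, 0}, {2, 2, 0},
        {8, 8, 1}, {8, 9, 1}, {9, 8, 1}, {9, 9, 1}
    };

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Linear score with weights {wx, wy, bias}. */
    static double score(double[] w, double[] row) {
        return w[0] * row[0] + w[1] * row[1] + w[2];
    }

    /** Trains on every row except the one at index skip. */
    static double[] train(double rate, int epochs, int skip) {
        double[] w = new double[3];
        for (int e = 0; e < epochs; e++) {
            for (int i = 0; i < DATA.length; i++) {
                if (i == skip) {
                    continue;
                }
                double[] row = DATA[i];
                double error = row[2] - sigmoid(score(w, row));
                w[0] += rate * error * row[0];
                w[1] += rate * error * row[1];
                w[2] += rate * error;
            }
        }
        return w;
    }

    /** Mean log-loss on the held-out point, averaged over all leave-one-out folds. */
    static double looLoss(double rate, int epochs) {
        double total = 0.0;
        for (int held = 0; held < DATA.length; held++) {
            double[] w = train(rate, epochs, held);
            double p1 = sigmoid(score(w, DATA[held]));
            double pTrue = DATA[held][2] == 1.0 ? p1 : 1.0 - p1;
            // Clamp so a saturated wrong prediction gives a large but finite loss.
            total += -Math.log(Math.max(pTrue, 1e-12));
        }
        return total / DATA.length;
    }

    /** Returns the candidate with the lowest leave-one-out loss (first one wins ties). */
    static double bestRate(double[] candidates, int epochs) {
        double best = candidates[0];
        double bestLoss = Double.POSITIVE_INFINITY;
        for (double rate : candidates) {
            double loss = looLoss(rate, epochs);
            if (loss < bestLoss) {
                bestLoss = loss;
                best = rate;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] candidates = {0.001, 0.01, 0.1, 1, 10};
        for (double rate : candidates) {
            System.out.println(rate + " => held-out log-loss " + looLoss(rate, 100));
        }
        System.out.println("selected rate: " + bestRate(candidates, 100));
    }
}
```

The point of the sketch is the shape of the procedure: every candidate rate is judged on points it was not trained on, rather than on how confident the model looks on its own training data. With only nine points the estimate is noisy, so treat the selected rate as a starting point.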
Answer 1 (score: 1)
I think there are potential problems with your classification example:
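- how OnlineLogisticRegression is used
- the training parameters (learningRate and so on)
- the missing additional predictor with the constant value 1
For details on these potential problems, see the book Mahout in Action.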
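Result after "fixing" the potential problems:
The test point <0.5, 0.5> is classified as cluster 0 with a probability of ca. 0.89, consistently across multiple runs.
This sounds like a reasonable output, because the other points near the origin (used to train the model) also belong to cluster 0.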
Code
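import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
public class SimpleClassifier {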
public static class Point {
public int x;
public int y;
public Point(int x, int y) {
this.x = x;
this.y = y;
}
@Override
public boolean equals(Object arg0) {
Point p = (Point) arg0;
return ((this.x == p.x) && (this.y == p.y));
}
@Override
public String toString() {
// TODO Auto-generated method stub
return this.x + " , " + this.y;
}
}
public static void main(String[] args) {
Map<Point, Integer> points = new HashMap<Point, Integer>();
points.put(new Point(0, 0), 0);
points.put(new Point(1, 1), 0);
points.put(new Point(1, 0), 0);
points.put(new Point(0, 1), 0);
points.put(new Point(2, 2), 0);
points.put(new Point(8, 8), 1);
points.put(new Point(8, 9), 1);
points.put(new Point(9, 8), 1);
points.put(new Point(9, 9), 1);
OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
System.out.println("training model \n");
for (int i=0; i<100; i++) {
List<Point> randomPoints = new ArrayList<>(points.keySet());
Collections.shuffle(randomPoints);
for (Point point : randomPoints) {
Vector v = getVector(point);
System.out.println(point + " belongs to " + points.get(point));
learningAlgo.train(points.get(point), v);
}
}
learningAlgo.close();
//now classify real data
Vector v = new RandomAccessSparseVector(3);
v.set(0, 0.5);
v.set(1, 0.5);
v.set(2, 1);
Vector r = learningAlgo.classify(v);
System.out.println(r);
System.out.println("ans = ");
System.out.println("no of categories = " + learningAlgo.numCategories());
System.out.println("no of features = " + learningAlgo.numFeatures());
System.out.println("Probability of cluster 0 = " + (1.0d - r.get(0)));
System.out.println("Probability of cluster 1 = " + r.get(0));
}
public static Vector getVector(Point point) {
Vector v = new DenseVector(3);
v.set(0, point.x);
v.set(1, point.y);
v.set(2, 1);
return v;
}
}
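The last two print statements rely on how classify differs from classifyFull. As far as I understand the Mahout API, classify returns n-1 scores and leaves out category 0, while classifyFull returns all n, so with two categories r.get(0) is the score of cluster 1 and cluster 0 is whatever remains. The plain-Java sketch below only illustrates that arithmetic. The class name ScoreExpansion, the helper toFull and the score 0.11 are hypothetical and not part of Mahout.

```java
/** Plain-Java sketch: expand n-1 scores (category 0 omitted) into n probabilities. */
class ScoreExpansion {
    /** scores[i] belongs to category i + 1; category 0 gets the remaining probability mass. */
    static double[] toFull(double[] scores) {
        double[] full = new double[scores.length + 1];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            full[i + 1] = scores[i];
            sum += scores[i];
        }
        full[0] = 1.0 - sum;
        return full;
    }

    public static void main(String[] args) {
        // Hypothetical score for cluster 1, chosen to be consistent with the "ca. 0.89" above.
        double[] full = toFull(new double[] {0.11});
        System.out.println("Probability of cluster 0 = " + full[0]);
        System.out.println("Probability of cluster 1 = " + full[1]);
    }
}
```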