Mahout - 简单的分类问题

时间:2012-06-25 12:32:01

标签: java classification mahout

我正在尝试构建一个简单的模型,可以将点分类为2D空间的2个分区

  1. 训练模型,指定少数点及其所属的分区
  2. 我使用该模型来预测测试点可能落下的组(分类)
  3. 不幸的是,我没有按预期得到答案。我在代码中遗漏了什么,或者我做错了什么?

    /**
     * Small Mahout demo: trains an OnlineLogisticRegression model on a handful of
     * labelled integer 2D points and prints the values returned by
     * {@code classifyFull} for the test point (0.5, 0.5).
     */
    public class SimpleClassifier {

        /** Integer 2D point; used as the key of the training map in {@link #main}. */
        public static class Point{
            public int x;
            public int y;

            public Point(int x,int y){
                this.x = x;
                this.y = y;
            }

            /**
             * Two points are equal when both coordinates match.
             * Returns {@code false} for {@code null} or non-Point arguments instead of throwing.
             */
            @Override
            public boolean equals(Object arg0) {
                if (this == arg0) {
                    return true;
                }
                if (!(arg0 instanceof Point)) {
                    return false;
                }
                Point p = (Point) arg0;
                return (this.x == p.x) && (this.y == p.y);
            }

            /** Kept consistent with {@link #equals} so Point is a well-behaved HashMap key. */
            @Override
            public int hashCode() {
                return 31 * x + y;
            }

            /** Formats the point as {@code "x , y"}. */
            @Override
            public String toString() {
                return this.x + " , " + this.y;
            }
        }

        /**
         * Builds the training set, trains the model with one pass over it, then
         * classifies the test point and prints the result.
         *
         * @param args unused
         */
        public static void main(String[] args) {

            // Training data: point -> category label (0 or 1).
            Map<Point,Integer> points = new HashMap<SimpleClassifier.Point, Integer>();

            points.put(new Point(0,0), 0);
            points.put(new Point(1,1), 0);
            points.put(new Point(1,0), 0);
            points.put(new Point(0,1), 0);
            points.put(new Point(2,2), 0);


            points.put(new Point(8,8), 1);
            points.put(new Point(8,9), 1);
            points.put(new Point(9,8), 1);
            points.put(new Point(9,9), 1);


            // 2 categories, 2 features (x and y), L1 prior.
            OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 2, new L1());
            learningAlgo.learningRate(50);

            //learningAlgo.alpha(1).stepOffset(1000);

            System.out.println("training model  \n" );
            // Single pass over the map; HashMap iteration order is unspecified.
            for(Map.Entry<Point,Integer> sample : points.entrySet()){
                Point point = sample.getKey();
                Integer label = sample.getValue();
                Vector v = getVector(point);
                System.out.println(point  + " belongs to " + label);
                learningAlgo.train(label, v);
            }

            learningAlgo.close();


            //now classify real data
            Vector v = new RandomAccessSparseVector(2);
            v.set(0, 0.5);
            v.set(1, 0.5);

            Vector r = learningAlgo.classifyFull(v);
            System.out.println(r);

            System.out.println("ans = " );
            System.out.println("no of categories = " + learningAlgo.numCategories());
            System.out.println("no of features = " + learningAlgo.numFeatures());
            System.out.println("Probability of cluster 0 = " + r.get(0));
            System.out.println("Probability of cluster 1 = " + r.get(1));

        }

        /**
         * Converts a point into a dense feature vector of size 2.
         *
         * @param point the point to convert
         * @return vector with {@code x} at index 0 and {@code y} at index 1
         */
        public static Vector getVector(Point point){
            Vector v = new DenseVector(2);
            v.set(0, point.x);
            v.set(1, point.y);

            return v;
        }
    }
    

    输出:

    ans = 
    no of categories = 2
    no of features = 2
    Probability of cluster 0 = 3.9580985042775296E-4
    Probability of cluster 1 = 0.9996041901495722
    

    在 99% 的情况下,输出都显示 cluster 1 的概率更高。这是为什么?

2 个答案:

答案 0 :(得分:5)

问题在于你没有包含偏置(截距)项,它的值恒为 1。你需要把这个偏置项(1)添加到你的 Point 类(即特征向量)中。

这是许多有经验的人在机器学习中犯下的一个非常基本的错误。在学习理论上投入一些时间可能是个好主意。 Andrew Ng's lectures是一个值得学习的好地方。

要让代码给出预期的输出,需要更改以下内容。

  1. 添加偏置项。
  2. 学习率太高,将其改为 10。
  3. 这样你就会得到类别 0 的概率 P(0) = 0.9999。

    这是一个完整的工作示例,可以提供正确的结果:

    import java.util.HashMap;
    import java.util.Map;
    
    import org.apache.mahout.classifier.sgd.L1;
    import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.RandomAccessSparseVector;
    import org.apache.mahout.math.Vector;
    
    
    class Point{
        public int x;
        public int y;
    
        public Point(int x,int y){
            this.x = x;
            this.y = y;
        }
    
        @Override
        public boolean equals(Object arg0) {
            Point p = (Point)  arg0;
            return( (this.x == p.x) &&(this.y== p.y));
        }
    
        @Override
        public String toString() {
            return  this.x + " , " + this.y ; 
        }
    }
    
    /**
     * Trains an OnlineLogisticRegression model on labelled 2D points, each extended
     * with a constant third feature of 1, and prints the values returned by
     * {@code classifyFull} for the test point (0.5, 0.5).
     */
    public class SimpleClassifier {

        /**
         * Builds the training set, trains the model with one pass over it, then
         * classifies the test point and prints the result.
         *
         * @param args unused
         */
        public static void main(String[] args) {

                // Training data: point -> category label (0 or 1).
                Map<Point,Integer> points = new HashMap<Point, Integer>();

                points.put(new Point(0,0), 0);
                points.put(new Point(1,1), 0);
                points.put(new Point(1,0), 0);
                points.put(new Point(0,1), 0);
                points.put(new Point(2,2), 0);

                points.put(new Point(8,8), 1);
                points.put(new Point(8,9), 1);
                points.put(new Point(9,8), 1);
                points.put(new Point(9,9), 1);


                // 2 categories, 3 features (x, y and the constant 1), L1 prior.
                OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
                learningAlgo.lambda(0.1);
                learningAlgo.learningRate(10);

                System.out.println("training model  \n" );

                // Single pass over the map; HashMap iteration order is unspecified.
                for(Map.Entry<Point,Integer> sample : points.entrySet()){

                    Point point = sample.getKey();
                    Integer label = sample.getValue();
                    Vector v = getVector(point);
                    System.out.println(point  + " belongs to " + label);
                    learningAlgo.train(label, v);
                }

                learningAlgo.close();

                // Test point (0.5, 0.5) with the same constant third feature as the training vectors.
                Vector v = new RandomAccessSparseVector(3);
                v.set(0, 0.5);
                v.set(1, 0.5);
                v.set(2, 1);

                Vector r = learningAlgo.classifyFull(v);
                System.out.println(r);

                System.out.println("ans = " );
                System.out.println("no of categories = " + learningAlgo.numCategories());
                System.out.println("no of features = " + learningAlgo.numFeatures());
                System.out.println("Probability of cluster 0 = " + r.get(0));
                System.out.println("Probability of cluster 1 = " + r.get(1));

        }

        /**
         * Converts a point into a dense feature vector of size 3.
         *
         * @param point the point to convert
         * @return vector with {@code x} at index 0, {@code y} at index 1 and the constant 1 at index 2
         */
        public static Vector getVector(Point point){
            Vector v = new DenseVector(3);
            v.set(0, point.x);
            v.set(1, point.y);
            v.set(2, 1);
            return v;
        }
    }
    

    输出:

    2 , 2 belongs to 0
    1 , 0 belongs to 0
    9 , 8 belongs to 1
    8 , 8 belongs to 1
    0 , 1 belongs to 0
    0 , 0 belongs to 0
    1 , 1 belongs to 0
    9 , 9 belongs to 1
    8 , 9 belongs to 1
    {0:2.470723149516907E-6,1:0.9999975292768505}
    ans = 
    no of categories = 2
    no of features = 3
    Probability of cluster 0 = 2.470723149516907E-6
    Probability of cluster 1 = 0.9999975292768505
    

    请注意,我在SimpleClassifier类之外定义了Point类,但这只是为了使代码更具可读性并且不是必需的。

    了解更改学习率时会发生什么。阅读有关交叉验证的说明,了解如何选择学习率。

    Learning Rate => Probability of cluster 0
    0.001 => 0.4991116089
    0.01 => 0.492481585
    0.1 => 0.469961472
    1 => 0.5327745322
    10 => 0.9745740393
    100 => 0
    1000 => 0
    

    选择学习率:

    1. 常见的做法是:随机梯度下降先从一个固定的学习率 α 开始,然后随着算法运行让 α 缓慢减小到零。这样可以确保参数收敛到全局最小值,而不仅仅是在最小值附近振荡。
    2. 在本例中我们使用的是常数 α,因此可以先选定一个初始值,运行梯度下降并观察成本函数,再相应地调整学习率。详细说明见此处(here)。

答案 1 :(得分:1)

我认为你的分类示例存在一些潜在问题,建议做如下调整:

  • 使用 OnlineLogisticRegression 训练参数(learningRate 等)的默认值
  • 引入常数偏置项(它只是另一个取值恒为 1 的预测变量)
  • 打乱训练数据的顺序(不要先提供第一组对应的全部训练数据,再提供第二组的数据)
  • 显著增加训练数据量

有关此潜在问题的详细信息,请参阅书籍Mahout in Action

"修复"

结果潜在问题:
测试点 <0.5, 0.5> 被分类为 cluster 0,概率约为 0.89,并且在多次运行中结果保持一致。这听起来是一个合理的输出,因为原点附近的其他点(用于训练模型)也属于 cluster 0。

**代码**

/**
 * Trains an OnlineLogisticRegression model (default training parameters) on
 * labelled 2D points, each extended with a constant third feature of 1, using
 * 100 shuffled passes over the data, then prints the result for the test point (0.5, 0.5).
 */
public class SimpleClassifier {

    /** Integer 2D point; used as the key of the training map in {@link #main}. */
    public static class Point {
        public int x;
        public int y;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }

        /**
         * Two points are equal when both coordinates match.
         * Returns {@code false} for {@code null} or non-Point arguments instead of throwing.
         */
        @Override
        public boolean equals(Object arg0) {
            if (this == arg0) {
                return true;
            }
            if (!(arg0 instanceof Point)) {
                return false;
            }
            Point p = (Point) arg0;
            return (this.x == p.x) && (this.y == p.y);
        }

        /** Kept consistent with {@link #equals} so Point is a well-behaved HashMap key. */
        @Override
        public int hashCode() {
            return 31 * x + y;
        }

        /** Formats the point as {@code "x , y"}. */
        @Override
        public String toString() {
            return this.x + " , " + this.y;
        }
    }

    /**
     * Builds the training set, trains the model over 100 shuffled epochs, then
     * classifies the test point and prints the result.
     *
     * @param args unused
     */
    public static void main(String[] args) {

        // Training data: point -> category label (0 or 1).
        Map<Point, Integer> points = new HashMap<Point, Integer>();

        points.put(new Point(0, 0), 0);
        points.put(new Point(1, 1), 0);
        points.put(new Point(1, 0), 0);
        points.put(new Point(0, 1), 0);
        points.put(new Point(2, 2), 0);


        points.put(new Point(8, 8), 1);
        points.put(new Point(8, 9), 1);
        points.put(new Point(9, 8), 1);
        points.put(new Point(9, 9), 1);


        // 2 categories, 3 features (x, y and the constant 1), L1 prior.
        OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());

        System.out.println("training model  \n");
        // Copy the keys once; every epoch reshuffles this list in place so the
        // two groups are interleaved in a different random order each pass.
        List<Point> trainingOrder = new ArrayList<>(points.keySet());
        for (int epoch = 0; epoch < 100; epoch++) {
            Collections.shuffle(trainingOrder);
            for (Point point : trainingOrder) {
                Integer label = points.get(point);
                Vector v = getVector(point);
                System.out.println(point + " belongs to " + label);
                learningAlgo.train(label, v);
            }
        }
        learningAlgo.close();


        //now classify real data
        Vector v = new RandomAccessSparseVector(3);
        v.set(0, 0.5);
        v.set(1, 0.5);
        v.set(2, 1);

        // Element 0 of classify()'s result is reported below as the probability of
        // cluster 1; cluster 0 is reported as its complement.
        Vector r = learningAlgo.classify(v);
        System.out.println(r);

        System.out.println("ans = ");
        System.out.println("no of categories = " + learningAlgo.numCategories());
        System.out.println("no of features = " + learningAlgo.numFeatures());
        System.out.println("Probability of cluster 0 = " + (1.0d - r.get(0)));
        System.out.println("Probability of cluster 1 = " + r.get(0));

    }

    /**
     * Converts a point into a dense feature vector of size 3.
     *
     * @param point the point to convert
     * @return vector with {@code x} at index 0, {@code y} at index 1 and the constant 1 at index 2
     */
    public static Vector getVector(Point point) {
        Vector v = new DenseVector(3);
        v.set(0, point.x);
        v.set(1, point.y);
        v.set(2, 1);

        return v;
    }
}