无法在JAVA中为WEKA实现简单k-means的余弦相似性

时间:2014-03-28 18:49:47

标签: cluster-analysis data-mining weka cosine-similarity

我对Java的ML的WEKA API很新。

由于weka中没有余弦相似度算法,我想通过修改WEKA的simpleKmeans算法将此算法添加到WEKA。

weka中的simpleKmeans算法使用EuclideanDistance,我希望使用余弦相似度而不是euclideanDistance。

我搜索了很多关于如何修改WEKA开源软件的simpleKmeans算法的代码,并在网上发现了这个问题(基本上是pedro的观点)

http://comments.gmane.org/gmane.comp.ai.weka/22681

这里提到的步骤是:

  1. 扩展weka.core.EuclideanDistance并覆盖距离(实例优先, 实例第二,PerformanceStats统计数据)方法。

  2. 使用EuclideanDistance作为类型将其实例化为扩展类,  将Instances作为扩展类构造函数的参数传递。

  3. 使用setDistanceFunction类传递的SimpleKMeans方法 EuclideanDistance实例。

  4. 这是WEKA流程第一部分的代码。

    /*
     * To change this license header, choose License Headers in Project Properties.
     * To change this template file, choose Tools | Templates
     * and open the template in the editor.
     */
    
    package weka.core;
    
    import weka.core.Attribute;
    //import weka.core.EuclideanDistance;
    import java.util.Enumeration;
    import weka.core.Instance;
    import weka.core.Instances;
    import weka.core.converters.ConverterUtils.DataSource;
    import weka.core.neighboursearch.PerformanceStats;
    import weka.core.TechnicalInformation.Type;
    
    /**
     *
     * @author Sgr
     */
    public class CosineSimilarity extends EuclideanDistance{
    
     public Instances m_Data = null;
     public String version ="1.0";
    
     @Override
     public double distance(Instance arg0, Instance arg1) {
      // TODO Auto-generated method stub
      return distance(arg0, arg1, Double.POSITIVE_INFINITY, null);
     }
    
     @Override
     public double distance(Instance arg0, Instance arg1, PerformanceStats arg2) {
      // TODO Auto-generated method stub
      return distance(arg0, arg1, Double.POSITIVE_INFINITY, arg2);
     }
    
     @Override
     public double distance(Instance arg0, Instance arg1, double arg2) {
      // TODO Auto-generated method stub
      return distance(arg0, arg1, arg2, null);
     }
    
     @Override
     public double distance(Instance first, Instance second, double cutOffValue,PerformanceStats arg3) {
    
        double distance = 0;
        int firstI, secondI;
        int firstNumValues = first.numValues();
        int secondNumValues = second.numValues();
        int numAttributes = m_Data.numAttributes();
        int classIndex = m_Data.classIndex();
        double normA, normB;
        normA = 0;
        normB = 0;
    
        for (int p1 = 0, p2 = 0; p1 < firstNumValues || p2 < secondNumValues;) {
    
            if (p1 >= firstNumValues)
                firstI = numAttributes;
            else firstI = first.index(p1);
    
    
            if (p2 >= secondNumValues)
                secondI = numAttributes;
            else secondI = second.index(p2);
    
            if (firstI == classIndex) {
                p1++;
               continue;
            }
    //   if ((firstI < numAttributes)) {
    //    p1++;
    //    continue;
    //   }
    
            if (secondI == classIndex) {
                p2++;
                continue;
            }
    //   if ((secondI < numAttributes)) {
    //    p2++;
    //    continue;
    //   }
    
            double diff;
    
            if (firstI == secondI) {
    
                diff = difference(firstI, first.valueSparse(p1), second.valueSparse(p2));
                normA += Math.pow(first.valueSparse(p1), 2);
                normB += Math.pow(second.valueSparse(p2), 2);
                p1++;
                p2++;
    
            } 
    
            else if (firstI > secondI) {
    
                diff = difference(secondI, 0, second.valueSparse(p2));
                normB += Math.pow(second.valueSparse(p2), 2);
                p2++;
    
            }
    
            else {
                diff = difference(firstI, first.valueSparse(p1), 0);
                normA += Math.pow(first.valueSparse(p1), 2);
                p1++;
            }
    
            if (arg3 != null)
                arg3.incrCoordCount();
    
            distance = updateDistance(distance, diff);
    
            if (distance > cutOffValue)
                return Double.POSITIVE_INFINITY;
            }
    
      //do the post here, don't depends on other functions
      //System.out.println(distance + " " + normA + " "+ normB);
            distance = distance/Math.sqrt(normA)/Math.sqrt(normB);
            distance = 1-distance;
    
            if(distance < 0 || distance > 1)
                System.err.println("unknown: " + distance);
    
            return distance;
    
        }
    
     public double updateDistance(double currDist, double diff){
    
         double result;
        result = currDist;
        result += diff;
    
        return result;
     }
    
     public double difference(int index, double val1, double val2){
    
         switch(m_Data.attribute(index).type()){
    
             case Attribute.NOMINAL:
                                return Double.NaN;
                                //break;
             case Attribute.NUMERIC:
                                  return val1 * val2;
                                //break;
        }
    
         return Double.NaN;
     }
    
     @Override
     public String getAttributeIndices() {
      // TODO Auto-generated method stub
      return null;
     }
    
     @Override
     public Instances getInstances() {
      // TODO Auto-generated method stub
      return m_Data;
     }
    
     @Override
     public boolean getInvertSelection() {
      // TODO Auto-generated method stub
      return false;
     }
    
     @Override
     public void postProcessDistances(double[] arg0) {
      // TODO Auto-generated method stub
    
     }
    
     @Override
     public void setAttributeIndices(String arg0) {
      // TODO Auto-generated method stub
    
     }
    
     @Override
     public void setInstances(Instances arg0) {
      // TODO Auto-generated method stub
      m_Data = arg0;
     }
    
     @Override
     public void setInvertSelection(boolean arg0) {
      // TODO Auto-generated method stub
    
    
      //do nothing
     }
    
     @Override
     public void update(Instance arg0) {
      // TODO Auto-generated method stub
    
      //do nothing
     }
    
     @Override
     public String[] getOptions() {
      // TODO Auto-generated method stub
      return null;
     }
    
     @Override
     public Enumeration listOptions() {
      // TODO Auto-generated method stub
      return null;
     }
    
     @Override
     public void setOptions(String[] arg0) throws Exception {
      // TODO Auto-generated method stub
    
     }
    
     @Override
     public String getRevision() {
      // TODO Auto-generated method stub
      return "Cosine Distance function writtern by Sgr, version " + version;
     }
    
    
    }
    

    但由于我不熟悉weka,我无法处理接下来的两个步骤。

    我在weka中看到了simpleKmeans的源代码,并观察到它创建了EuclideanDistance类的实例,但我对进一步的过程一无所知。

    请帮助我完成接下来的两个步骤。如果在余弦相似度的实现中存在错误,请仔细弄清楚。此外,如果任何人都可以在weka中修改SimpleKmeans的代码以用于我的余弦实现,或者向我解释我应该在该代码中进行更改的位置,那将非常有用。

1 个答案:

答案 0 :(得分:1)

在集群方面,Weka真的是。它也很慢。

你看过ELKI了吗?与Weka相比,它在聚类和异常检测方面有更多选择。您可以在ELKI中以k-means开箱即用的方式试验余弦相似性。

但请注意, k-means不是基于距离的。它最小化方差(平方和),如果使用其他距离函数, k-means可能会停止收敛。原因是均值是L2最优中心,但优化其他距离函数。它只是优化平方和,与平方欧几里德距离相同。

通常,具有其他距离的k-means(例如余弦)可以工作并收敛您的数据集。但收敛证明需要平方和。实际上,当簇的平均值变为0时(即使您的数据不包含零向量),使用具有余弦相似性的k均值也可能产生0除法误差。

有许多变体,例如k-medoids,支持其他距离函数。据我所知,它们也应该在ELKI中可用。