Error when trying to classify a new instance in Java with Weka - No output instance format defined

Date: 2014-11-28 19:25:31

Tags: java classification weka

I am trying to use Weka in my project to classify text documents with a NaiveBayes classifier. I found the following two classes on this site.

The first class, MyFilteredLearner, builds, trains, evaluates and saves the classifier to disk, and all of that works fine.

The second class, MyFilteredClassifier, loads a single text string from a text file and successfully turns it into an instance. It also restores the classifier from disk. What it does not do is classify the instance via the classify() method; instead it returns the exception message 'No output instance format defined'.

I have spent a long time searching for an answer and have tried installing both the developer and the stable version of Weka, but I still get the same problem.

Does anyone know what is wrong in the code, or what needs to be added or done differently? The file details and code are below:

ARFF file used to train the classifier (spam.ARFF):

@relation sms_test

@attribute spamclass {spam,ham}
@attribute text String

@data
ham,'Go until jurong point, crazy.. Available only in bugis n great world la e buffet...Cine there got amore wat...'
etc……………………………………………………………………

Single-line text file containing the new instance (toClassify.txt):

this is spam or not, who knows?

Code for MyFilteredLearner:

import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.util.Random;

import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.meta.FilteredClassifier;
import weka.core.Instances;
import weka.core.converters.ArffLoader.ArffReader;
import weka.filters.unsupervised.attribute.StringToWordVector;

public class MyFilteredLearner {
    Instances trainData;
    StringToWordVector filter;
    FilteredClassifier classifier;

    public void loadDataset(String fileName) {
        try {
            BufferedReader reader = new BufferedReader(new FileReader(fileName));
            ArffReader arff = new ArffReader(reader);
            trainData = arff.getData();
            System.out.println("===== Loaded dataset: " + fileName + " =====");
            reader.close();
        }
        catch (IOException e) {
            System.out.println("Problem found when reading: " + fileName);
        }
    }

    public void learn() {
        try {
            trainData.setClassIndex(0);
            classifier = new FilteredClassifier();
            filter = new StringToWordVector();
            filter.setAttributeIndices("last");
            classifier.setFilter(filter);
            classifier.setClassifier(new NaiveBayes());
            classifier.buildClassifier(trainData);
            System.out.println("===== Training on filtered (training) dataset done =====");
        }
        catch (Exception e) {
            System.out.println("Problem found when training");
        }
    }

    public void evaluate() {
        try {
            trainData.setClassIndex(0);
            filter = new StringToWordVector();
            filter.setAttributeIndices("last");
            classifier = new FilteredClassifier();
            classifier.setFilter(filter);
            classifier.setClassifier(new NaiveBayes());
            Evaluation eval = new Evaluation(trainData);
            eval.crossValidateModel(classifier, trainData, 4, new Random(1));
            System.out.println(eval.toSummaryString());
            System.out.println(eval.toClassDetailsString());
            System.out.println("===== Evaluating on filtered (training) dataset done =====");
        }
        catch (Exception e) {
            System.out.println("Problem found when evaluating");
        }
    }

    public void saveModel(String fileName) {
        try {
            ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName));
            out.writeObject(classifier);
            System.out.println("Saved model: " + out.toString());
            out.close();
            System.out.println("===== Saved model: " + fileName + "=====");
        }
        catch (IOException e) {
            System.out.println("Problem found when writing: " + fileName);
        }
    }
}

Code for MyFilteredClassifier:

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;

import weka.classifiers.meta.FilteredClassifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.FastVector;
import weka.core.Instances;
import weka.filters.unsupervised.attribute.StringToWordVector;

public class MyFilteredClassifier {
    String text;
    Instances instances;
    FilteredClassifier classifier;  
    StringToWordVector filter;

    public void load(String fileName) {
        try {
            BufferedReader reader = new BufferedReader(new FileReader(fileName));
            String line;
            text = "";
            while ((line = reader.readLine()) != null) {
                text = text + " " + line;
            }
            System.out.println("===== Loaded text data: " + fileName + " =====");
            reader.close();
            System.out.println(text);
        }
        catch (IOException e) {
            System.out.println("Problem found when reading: " + fileName);
        }
    }

    public void makeInstance() {
        FastVector fvNominalVal = new FastVector(2);
        fvNominalVal.addElement("spam");
        fvNominalVal.addElement("ham");
        Attribute attribute1 = new Attribute("class", fvNominalVal);
        Attribute attribute2 = new Attribute("text",(FastVector) null);
        FastVector fvWekaAttributes = new FastVector(2);
        fvWekaAttributes.addElement(attribute1);
        fvWekaAttributes.addElement(attribute2);
        instances = new Instances("Test relation", fvWekaAttributes,1);           
        instances.setClassIndex(0);
        DenseInstance instance = new DenseInstance(2);
        instance.setValue(attribute2, text);
        instances.add(instance);
        System.out.println("===== Instance created with reference dataset =====");
        System.out.println(instances);
    }

    public void loadModel(String fileName) {
        try {
            ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName));
            Object tmp = in.readObject();
            classifier = (FilteredClassifier) tmp;
            in.close();
            System.out.println("===== Loaded model: " + fileName + "=====");
        } 
        catch (Exception e) {
        System.out.println("Problem found when reading: " + fileName);
        }
    }

    public void classify() {
        try {
            double pred = classifier.classifyInstance(instances.instance(0));
            System.out.println("===== Classified instance =====");
            System.out.println("Class predicted: " + instances.classAttribute().value((int) pred));
        }
        catch (Exception e) {
            System.out.println("Error: " + e.getMessage());
        }       
    }

    public static void main(String args[]) {
        MyFilteredLearner c = new MyFilteredLearner();
        c.loadDataset("spam.ARFF");
        c.learn();
        c.evaluate();
        c.saveModel("spamClassifier.binary");
        MyFilteredClassifier c1 = new MyFilteredClassifier();
        c1.load("toClassify.txt");
        c1.loadModel("spamClassifier.binary");
        c1.makeInstance();
        c1.classify();
    }

}

1 Answer:

Answer 0 (score: 1):

It looks like you changed the code from the blog's GitHub repository in one detail, and that is what causes the error:

c.learn();
c.evaluate();

vs.

c.evaluate();
c.learn();

The evaluate() method resets the classifier with the following line:

classifier = new FilteredClassifier();

but it does not build a model. The actual evaluation works on a copy of the classifier that is passed in, so the original classifier (the one held by your class) remains untrained:

// weka/classifiers/Evaluation.java (method: crossValidateModel)
Classifier copiedClassifier = Classifier.makeCopy(classifier);
copiedClassifier.buildClassifier(train);

So you first build the model, then overwrite it while evaluating, and finally save an uninitialized model. Swap the two calls so that training happens right before the model is saved to a file, and it will work.
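
For reference, here is a minimal sketch of the corrected main(), assuming both classes stay exactly as posted; the only change is the order of the two calls, so that buildClassifier() runs on the classifier object that actually gets serialized. As far as I can tell, that call is also what gives the internal StringToWordVector filter its output format, which is where the 'No output instance format defined' message comes from when the model is left untrained.

public static void main(String args[]) {
    MyFilteredLearner c = new MyFilteredLearner();
    c.loadDataset("spam.ARFF");
    c.evaluate();                          // cross-validates on copies; leaves the classifier field untrained
    c.learn();                             // build the real model last, so nothing overwrites it afterwards
    c.saveModel("spamClassifier.binary");  // now a trained FilteredClassifier is written to disk
    MyFilteredClassifier c1 = new MyFilteredClassifier();
    c1.load("toClassify.txt");
    c1.loadModel("spamClassifier.binary");
    c1.makeInstance();
    c1.classify();
}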