在golearn中使用.Predict()函数时获取<nil>

时间:2016-01-26 01:21:10

标签: go machine-learning

我正在使用golearn examples文件夹中的knnclassifier_iris.go示例。我用我自己的一个替换了虹膜数据集,只要我在读入的数据的一定百分比上训练我的数据,所有函数都能正常工作,我得到一些输出。但是,当我清楚地提到训练和测试数据集,然后在拟合训练数据集后对测试数据集运行预测时,当我尝试打印预测时,我得到一个零结果。我不知道为什么我得到零值,所以我真的很感激一些帮助。

我的代码:

package main

import (
    "fmt"
    "github.com/sjwhitworth/golearn/base"
    "github.com/sjwhitworth/golearn/evaluation"
    "github.com/sjwhitworth/golearn/knn"
)

func main() {
    trainData, err := base.ParseCSVToInstances("~/Desktop/churn_train.csv", true)
    if err != nil {
        panic(err)
    }
    fmt.Println(trainData)
    testData, err := base.ParseCSVToInstances("~/Desktop/churn_test.csv", false)
    if err != nil {
        panic(err)
    }
    fmt.Println(trainData)
    fmt.Println(testData)

    //Initialises a new KNN classifier
    cls := knn.NewKnnClassifier("euclidean", 2)
    cls.Fit(trainData)

//Calculates the Euclidean distance and returns the most popular label
    predictions := cls.Predict(testData)
    fmt.Println(predictions) //GETTING <NIL> AS OUTPUT

    // Prints precision/recall metrics
    confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
    if err != nil {
        panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error())) //ERROR CAUSED HERE DUE TO GETTING <NIL>
    }
    fmt.Println(evaluation.GetSummary(confusionMat))

}

1 个答案:

答案 0 :(得分:0)

(Just in case anybody's stumbled across this on Google). The issue tends to arise when the second ParseCSVToInstances produces instances which are subtly different from the first. To ensure that this isn't the problem, use ParseCSVToTemplatedInstances, so

testData, err := base.ParseCSVToInstances("~/Desktop/churn_test.csv", false)

becomes

 testData, err := base.ParseCSVToTemplatedInstances("~/Desktop/churn_test.csv", false, trainData)