我是否在k-fold cross_validation中使用相同的Tfidf词汇表

时间:2017-09-02 04:57:38

标签: python scikit-learn cross-validation tf-idf

我正在进行基于TF-IDF向量空间模型的文本分类。我只有不超过3000个样本。为了公平评估,我正在使用5倍交叉验证来评估分类器。但是有什么困惑我是在每次折叠交叉验证中是否有必要重建TF-IDF向量空间模型。也就是说,我是否需要重建词汇表并在每个折叠交叉验证中重新计算词汇表中的IDF值?

目前我正在基于scikit-learn工具包进行TF-IDF转换,并使用SVM训练我的分类器。我的方法如下:首先,我将手中的样本除以3:1的比例,75%的样本用于拟合TF-IDF向量空间模型的参数.Herein,参数是大小词汇及其中包含的术语,也是词汇表中每个术语的IDF值。然后我在这个TF-IDF SVM中转换余数并使用这些向量来制作5 - 交叉验证(值得注意的是,我不使用之前的75%样本进行转换)。

我的代码如下:

# train, test split, the train data is just for TfidfVectorizer() fit
x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, train_size=0.75, random_state=0)
tfidf = TfidfVectorizer()
tfidf.fit(x_train)

# vectorizer test data for 5-fold cross-validation
x_test = tfidf.transform(x_test)

 scoring = ['accuracy']
 clf = SVC(kernel='linear')
 scores = cross_validate(clf, x_test, y_test, scoring=scoring, cv=5, return_train_score=False)
 print(scores)

我的困惑在于,我的方法是否正在进行TF-IDF转换和进行5倍交叉验证是否正确,或者是否有必要使用训练数据重建TF-IDF矢量模型空间然后转换为TF-IDF向量包含火车和测试数据?如下:

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_index, test_index in skf.split(data_x, data_y):
    x_train, x_test = data_x[train_index], data_x[test_index]
    y_train, y_test = data_y[train_index], data_y[test_index]

    tfidf = TfidfVectorizer()
    x_train = tfidf.fit_transform(x_train)
    x_test = tfidf.transform(x_test)

        clf = SVC(kernel='linear')
        clf.fit(x_train, y_train)
        y_pred = clf.predict(x_test)
        score = accuracy_score(y_test, y_pred)
        print(score)

1 个答案:

答案 0 :(得分:0)

您用来构建StratifiedKFold的{​​{1}}方法是正确的方法,这样做可以确保仅根据训练数据集生成要素。

如果您考虑在整个数据集上构建TfidfVectorizer(),那么即使我们没有明确地提供测试数据集,它也会将测试数据集泄漏到模型中。当包含测试文档时,诸如词汇量,词汇中每个术语的IDF值之类的参数会大大不同。

更简单的方法可能是使用管道和cross_validate。

使用它!

TfidfVectorizer()

注意:仅对测试数据执行from sklearn.pipeline import make_pipeline clf = make_pipeline(TfidfVectorizer(), svm.SVC(kernel='linear')) scores = cross_validate(clf, data_x, data_y, scoring=['accuracy'], cv=5, return_train_score=False) print(scores) 并没有用。我们必须对cross_validate数据集进行处理。