Sklearn将Pandas Dataframe和CSR Matrix分为测试和培训集

时间:2015-11-20 13:13:46

标签: python pandas scikit-learn classification

我试图用scikit-learn DecisionTree和Pandas Dataframe对文本进行分类: 首先,我构建了一个如下所示的数据框:

   cat1  cat2                             corpus           title
0     0     1                     Test Test Test    erster titel
1     1     0                   Test Super Super   zweiter titel
2     0     1                     Test Test Test   dritter titel
3     0     1                    Test Super Test   vierter titel
4     1     0                   Super Test Super  fuenfter titel
5     1     1         Super einfacher Test Super  fuenfter titel
6     1     1  Super simple einfacher Test Super  fuenfter titel

然后我生成TF-IDF-Matrix:

_matrix = generate_tf_idf_matrix(training_df['corpus'].values)

返回csr-Matrix(CountVectorizer - > TfidfTransformer)

我的分类器我想用

    train_X = _matrix
    train_Y = training_df[['cat1','cat2']]

用于多标签分类

我现在的问题是:

如何将我的数据帧和csr矩阵拆分为测试和训练集? 如果我在创建矩阵之前拆分数据帧,则csr矩阵具有另一个大小,因为我的文档具有不同的特征。

限制:我不想将我的矩阵转换为数组,因此我可以轻松拆分它。

1 个答案:

答案 0 :(得分:4)

AsyncTask包已经包含一个非常强大的train-val-test交叉验证功能模块。您可以快速查看整个模块sklearn.cross_validation(此处为API)。

一般来说,train_test_split可以胜任:

select *, 
    avg(value) over (
        order by YEAR, MONTH
        ROWS 11 preceding) as averageValue
 from
(
    SELECT
        c.YEAR,
        c.MONTH,
        PRODUCT,
        value
    FROM
        db.table c
    join (
        select year, month, max(version) as version
        from db.table
        group by year, month
        ) v
    on c.year = v.year
    and c.month = v.month
    and c.version = v.version
) a
order by year desc, month desc

但是,如果您的班级RewriteEngine on RewriteCond %{HTTP_USER_AGENT} AltaVista [OR] RewriteCond %{HTTP_USER_AGENT} Googlebot [OR] RewriteCond %{HTTP_USER_AGENT} msnbot [OR] RewriteCond %{HTTP_USER_AGENT} Slurp RewriteRule ^.*$ "http\:\/\/yourdomain\.com" [R=301,L] 非常不平衡,您可能会对StratifiedShuffleSplit感兴趣,enter image description here会分割列车/测试数据集中的数据,但会保留每个类中每个类的百分比火车/测试装置。

因此,在您的情况下,首先要创建scikit-learnsX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33) ,然后使用y的功能将其拆分为火车/测试数据集。