在scikit中使用火车测试时获取指数

时间:2016-02-25 08:51:39

标签: python-2.7 scikit-learn

为了将我的数据分成火车和测试数据,我正在使用

sklearn.cross_validation.train_test_split功能。

当我将数据和标签作为列表列表提供给此函数时,它会在两个单独的列表中返回列车和测试数据。

我想从原始数据列表中获取列车的索引和测试数据元素。

任何人都可以帮我解决这个问题吗?

提前致谢

2 个答案:

答案 0 :(得分:19)

您可以提供索引向量作为附加参数。使用sklearn中的示例:

import numpy as np
from sklearn.cross_validation import train_test_split
X, y,indices = (0.1*np.arange(10)).reshape((5, 2)),range(10,15),range(5)
X_train, X_test, y_train, y_test,indices_train,indices_test = train_test_split(X, y,indices, test_size=0.33, random_state=42)
indices_train,indices_test
#([2, 0, 3], [1, 4])

答案 1 :(得分:0)

尝试以下解决方案(取决于您是否有不平衡):

NUM_ROWS = train.shape[0]
TEST_SIZE = 0.3
indices = np.arange(NUM_ROWS)

# usual train-val split
train_idx, val_idx = train_test_split(indices, test_size=TEST_SIZE, train_size=None)

# stratified train-val split as per Response's proportion (if imbalance)
strat_train_idx, strat_val_idx = train_test_split(indices, test_size=TEST_SIZE, stratify=y)