自定义交叉验证拆分sklearn

时间:2014-06-06 09:19:23

标签: python validation machine-learning scikit-learn cross-validation

我正在尝试拆分交叉验证的数据集和sklearn中的GridSearch。 我想定义自己的分割,但GridSearch只采用内置的交叉验证方法。

但是,我无法使用内置的交叉验证方法,因为我需要将某些示例组放在同一个视图中。 所以,如果我有例子: [A1,A2,A3,A4,A5,B1,B2,B3,C1,C2,C3,C4,......,Z1,Z2,Z3]

我想进行交叉验证,以便每个组[A,B,C ......]中的示例仅存在于一个部分中。

即K1包含[D,E,G,J,K ...],K2包含[A,C,L,M,...],K3包含[B,F,I,...]等

2 个答案:

答案 0 :(得分:11)

通常可以使用sklearn.cross_validation.LeaveOneLabelOut完成此类事情。您只需构建一个对您的组进行编码的标签向量。即,K1中的所有样本都会带有标签1K2中的所有样本都会带有标签2,依此类推。

这是一个带有虚假数据的完全可运行的示例。重要的行是创建cv对象的行,以及对cross_val_score

的调用
import numpy as np

n_features = 10

# Make some data
A = np.random.randn(3, n_features)
B = np.random.randn(5, n_features)
C = np.random.randn(4, n_features)
D = np.random.randn(7, n_features)
E = np.random.randn(9, n_features)

# Group it
K1 = np.concatenate([A, B])
K2 = np.concatenate([C, D])
K3 = E

data = np.concatenate([K1, K2, K3])

# Make some dummy prediction target
target = np.random.randn(len(data)) > 0

# Make the corresponding labels
labels = np.concatenate([[i] * len(K) for i, K in enumerate([K1, K2, K3])])

from sklearn.cross_validation import LeaveOneLabelOut, cross_val_score

cv = LeaveOneLabelOut(labels)

# Use some classifier in crossvalidation on data
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression()
scores = cross_val_score(lr, data, target, cv=cv)

但是,您可能会遇到想要完全手动定义折叠的情况。在这种情况下,您需要创建一个iterable(例如一个list)对夫妻(train, test),通过索引指示要进入您的火车的样本和每个折叠的测试集。我们来看看:

# create train and test folds from our labels:
cv_by_hand = [(np.where(labels != label)[0], np.where(labels == label)[0])
               for label in np.unique(labels)]

# We check this against our existing cv by converting the latter to a list
cv_to_list = list(cv)

print cv_by_hand
print cv_to_list

# Check equality
for (train1, test1), (train2, test2) in zip(cv_by_hand, cv_to_list):
    assert (train1 == train2).all() and (test1 == test2).all()

# Use the created cv_by_hand in cross validation
scores2 = cross_val_score(lr, data, target, cv=cv_by_hand)


# assert equality again
assert (scores == scores2).all()

答案 1 :(得分:0)

我知道这个问题已经很老了,但我遇到了同样的问题。看起来很快会有一个贡献可以让你这样做:

https://github.com/scikit-learn/scikit-learn/pull/4583

相关问题