我之前使用过sklearn.cross_validation import StratifiedShuffleSplit
。但是它没有用。然后我尝试了from sklearn.model_selection import StratifiedShuffleSplit
,它抛出以下错误:
import os
import numpy as np
def _load_cifar10_batch(file):
import pickle as cPickle
fo = open(file, 'rb')
dict = cPickle.load(fo, encoding='latin1')
fo.close()
return dict['data'].reshape(-1, 32, 32, 3), dict['labels'] # reshaping the data to 32 x 32 x 3
print('Loading...')
batch_fns = [os.path.join("./", 'cifar-10-batches-py', 'data_batch_' + str(i)) for i in range(1, 6)]
data_batches = [_load_cifar10_batch(fn) for fn in batch_fns]
data_all = np.vstack([data_batches[i][0] for i in range(len(data_batches))]).astype('float')
labels_all = np.vstack([data_batches[i][1] for i in range(len(data_batches))]).flatten()
#Splitting the whole training set into 92:8
seed=7
from sklearn.model_selection import StratifiedShuffleSplit
data_split = StratifiedShuffleSplit(labels_all,1, test_size=0.08,random_state=seed) #creating data_split object with 8% test size
for train_index, test_index in data_split:
split_data_92, split_data_8 = data_all[train_index], data_all[test_index]
split_label_92, split_label_8 = labels_all[train_index], labels_all[test_index]
TypeError Traceback (most recent call last)
<ipython-input-29-a61d57ed4f74> in <module>
6
7
----> 8 data_split = StratifiedShuffleSplit(labels_all,1, test_size=0.08,random_state=seed) #creating data_split object with 8% test size
9
10 for train_index, test_index in data_split:
TypeError: __init__() got multiple values for argument 'test_size'