Which parameter changes are needed when importing StratifiedShuffleSplit from sklearn.model_selection instead of sklearn.cross_validation?

Time: 2019-03-24 18:21:17

Tags: scikit-learn cross-validation

I am trying to run Python 3 code for isolated speech recognition, and I get a DeprecationWarning for using:

from sklearn.cross_validation import StratifiedShuffleSplit

To get rid of this warning, I simply imported StratifiedShuffleSplit from sklearn.model_selection instead of sklearn.cross_validation, and after running the code I got:

TypeError: 'StratifiedShuffleSplit' object is not iterable

Perhaps this is because in

class sklearn.cross_validation.StratifiedShuffleSplit(y, n_iter=10, test_size=0.1, train_size=None, random_state=None)

y is an array.

By contrast, in

class sklearn.model_selection.StratifiedShuffleSplit(n_splits=10, test_size=None, train_size=None, random_state=None)

the constructor takes no array. My code:

from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(all_labels, test_size=0.1, random_state=0)

for n,i in enumerate(all_obs):
    all_obs[n] /= all_obs[n].sum(axis=0)

for train_index, test_index in sss:
    X_train, X_test = all_obs[train_index, ...], all_obs[test_index, ...]
    y_train, y_test = all_labels[train_index], all_labels[test_index]
ys = set(all_labels)
ms = [gmmhmm(7) for y in ys]

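How should I replace all_labels? In sklearn.cross_validation it is passed to the constructor as an array, but the sklearn.model_selection constructor does not accept an array argument.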

1 answer:

Answer 0: (score: 0)

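The difference between the two is:

  • sklearn.model_selection.StratifiedShuffleSplit is a cross-validator: it stores only the configuration, and its split(X, y) method produces the indices

  • sklearn.cross_validation.StratifiedShuffleSplit is a cross-validation iterator: it takes the labels y in its constructor and is iterated over directly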
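To make the distinction concrete, here is a minimal sketch on made-up toy data (the arrays X and y are placeholders, not the asker's data). It shows that the new-style object cannot be iterated directly, while its split() method yields the index pairs. The old n_iter argument corresponds to n_splits in the new class.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Toy data (hypothetical): 20 samples, two balanced classes
X = np.arange(40, dtype=float).reshape(20, 2)
y = np.array([0, 1] * 10)

# n_splits plays the role of the old n_iter argument
sss = StratifiedShuffleSplit(n_splits=3, test_size=0.1, random_state=0)

# The splitter holds only configuration, so iterating it directly fails
try:
    iter(sss)
    splitter_is_iterable = True
except TypeError:
    splitter_is_iterable = False

# split() is what yields the (train_index, test_index) pairs
splits = list(sss.split(X, y))

print(splitter_is_iterable, len(splits))
```

The try/except reproduces the TypeError from the question in a controlled way; the data only enters the picture when split() is called.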

So the correct usage in your example is:

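from sklearn.model_selection import StratifiedShuffleSplit

# The constructor takes only configuration; the labels are passed to split()
sss = StratifiedShuffleSplit(test_size=0.1, random_state=0)

# Normalise each observation column-wise, as in the question
for n, i in enumerate(all_obs):
    all_obs[n] /= all_obs[n].sum(axis=0)

# split() yields the (train_index, test_index) pairs
for train_index, test_index in sss.split(all_obs, all_labels):
    print(train_index, test_index)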
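To connect this back to the loop body in the question, the following sketch uses the indices to slice the data. The arrays all_obs and all_labels here are random stand-ins with assumed shapes and labels, since the real data is not shown in the question.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical stand-ins for the asker's data:
# 30 observations of shape (6, 4) with three balanced string labels
rng = np.random.default_rng(0)
all_obs = rng.random((30, 6, 4)) + 0.1
all_labels = np.array(["a", "b", "c"] * 10)

# Normalise each observation so that every column sums to 1
for n in range(len(all_obs)):
    all_obs[n] /= all_obs[n].sum(axis=0)

sss = StratifiedShuffleSplit(n_splits=10, test_size=0.1, random_state=0)

# Same loop body as in the question, driven by split() instead of the object
for train_index, test_index in sss.split(all_obs, all_labels):
    X_train, X_test = all_obs[train_index, ...], all_obs[test_index, ...]
    y_train, y_test = all_labels[train_index], all_labels[test_index]

print(X_train.shape, X_test.shape)
```

As in the question, the variables hold only the last of the generated splits once the loop ends. If a single train/test split is all that is needed, passing n_splits=1 avoids generating the rest.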

Reading the examples in the sklearn.model_selection.StratifiedShuffleSplit documentation may also help.