Splitting an imbalanced dataset with sklearn.model_selection

Date: 2019-05-07 13:48:02

Tags: python machine-learning scikit-learn dataset

I am using the following code to split my dataset into training, validation, and test sets.

from sklearn.model_selection import train_test_split

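# First split: 70% training, 30% held out.
X_train, X_test, y_train, y_test = train_test_split(
        X_data, y_data, test_size=0.3, random_state=42)

# Second split: divide the held-out 30% evenly into test and validation.
X_test, X_val, y_test, y_val = train_test_split(
        X_test, y_test, test_size=0.5, random_state=42)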

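The problem is that my dataset is heavily imbalanced: some classes have 500 samples, while others have only 70. Is this splitting approach sound in that case? Does train_test_split sample purely at random, or does sklearn use some method to keep the class distribution the same across all of the sets?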

1 Answer:

Answer 0: (score: 1)

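By default, train_test_split shuffles the data and samples at random without looking at the class labels, so the class proportions can differ between the sets. You should use the stratify option (see the docs), which keeps the class proportions of the array you pass to it approximately the same in each split: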

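# First split: 70% training, 30% held out, stratified on the full label array.
X_train, X_test, y_train, y_test = train_test_split(
        X_data, y_data, test_size=0.3, random_state=42, stratify=y_data)

# Second split: divide the held-out 30% evenly, stratified on its own labels.
X_test, X_val, y_test, y_val = train_test_split(
        X_test, y_test, test_size=0.5, random_state=42, stratify=y_test)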
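
To check that the stratified split behaves as described, you can compare the class proportions in each set. The following is a minimal sketch, not part of the original answer: the data is synthetic (500 samples of one class and 70 of another, mirroring the numbers in the question), and the names X_rest, y_rest, and class_fractions are hypothetical, introduced only for illustration.

```python
from collections import Counter

import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic imbalanced labels (an assumption): 500 samples of class 0, 70 of class 1.
y_data = np.array([0] * 500 + [1] * 70)
X_data = np.arange(len(y_data)).reshape(-1, 1)

# 70% training, 30% held out, stratified on the labels.
X_train, X_rest, y_train, y_rest = train_test_split(
    X_data, y_data, test_size=0.3, random_state=42, stratify=y_data
)
# Held-out part divided evenly into test and validation, stratified again.
X_test, X_val, y_test, y_val = train_test_split(
    X_rest, y_rest, test_size=0.5, random_state=42, stratify=y_rest
)


def class_fractions(labels):
    """Return the share of each class in a label array (hypothetical helper)."""
    counts = Counter(labels.tolist())
    return {cls: counts[cls] / len(labels) for cls in sorted(counts)}


fractions = {
    name: class_fractions(labels)
    for name, labels in [
        ("all", y_data),
        ("train", y_train),
        ("test", y_test),
        ("val", y_val),
    ]
}
for name, frac in fractions.items():
    print(name, frac)
```

With stratification, the minority-class share printed for each set should stay close to its share in the full dataset (70 of 570), up to rounding in the smaller sets. Note that stratify needs at least two samples of every class in the array being split; otherwise train_test_split raises an error.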