分割火车和测试数据以进行CNN图像分类

时间:2020-06-08 16:04:28

标签: machine-learning computer-vision

所以我对图像有二进制分类问题,对于类a和b有平衡的数据集。

我每个班级有307张图片。我想问一下,当我拆分训练和测试数据集时,训练和测试是否也应针对每个班级进行平衡?或任何分割数据集的方法

1 个答案:

答案 0 :(得分:1)

您可以使用sklearn.model_selection.StratifiedShuffleSplit,它使用分层随机抽样,比例随机抽样或配额随机抽样。这样可以使分布更好。

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
# dummy dataset
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 0, 1, 1, 1])
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
sss.get_n_splits(X, y)

print(sss)

for train_index, test_index in sss.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

对于CNN来说,307可能很低,您也可以使用数据增强来增加样本。

https://github.com/mdbloice/Augmentor