训练/测试分割物体检测

时间:2020-03-05 16:31:29

标签: machine-learning scikit-learn computer-vision object-detection train-test-split

是否存在任何脚本/函数来拆分数据,以计算每个图像中类出现的数量并平衡它们? 我已经尝试过sklearn train_test_split了:

data = pd.read_csv('train_labels.csv')
data.head()

类是我要预测的,在一张图像上我可以有0..n个矩形,每个矩形都有一个类。

enter image description here

data = data.drop_duplicates(subset="filename")
y = data['class']
X = data.drop('class',axis = 1)
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2)

但是,当我删除文件名中的重复项时,我失去了信息,也许我发送了文件进行培训或与其他许多类一起进行测试,但是如果不删除它们,我可以在培训中进行文件测试。

感谢您的帮助。

1 个答案:

答案 0 :(得分:1)

它们是scikit-multilearn库,将有助于拆分多标签数据。 数据格式应该像

df = pd.DataFrame({"train": [1,2,3,4,5,6,7,8],
                      "y1": [1,1,0,0,1,1,0,0],
                      "y2": [0,0,0,0,1,1,1,1]})

安装:pip install scikit-multilearn

文档:http://scikit.ml/stratification.html

代码:

from skmultilearn.model_selection import iterative_train_test_split
X_train, y_train, X_test, y_test = iterative_train_test_split(X, y, test_size = 0.5)
X_train = array([[1],[4],[6],[7]])
y_train = array([[1, 0],[0, 0],[1, 1],[0, 1]])
X_test =  array([[2],[3],[5],[8]])
y_test =  array([[1, 0],[0, 0],[1, 1],[0, 1]]))