是否存在任何脚本/函数来拆分数据,以计算每个图像中类出现的数量并平衡它们? 我已经尝试过sklearn train_test_split了:
data = pd.read_csv('train_labels.csv')
data.head()
类是我要预测的,在一张图像上我可以有0..n个矩形,每个矩形都有一个类。
data = data.drop_duplicates(subset="filename")
y = data['class']
X = data.drop('class',axis = 1)
X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2)
但是,当我删除文件名中的重复项时,我失去了信息,也许我发送了文件进行培训或与其他许多类一起进行测试,但是如果不删除它们,我可以在培训中进行文件测试。
感谢您的帮助。
答案 0 :(得分:1)
它们是scikit-multilearn库,将有助于拆分多标签数据。 数据格式应该像
df = pd.DataFrame({"train": [1,2,3,4,5,6,7,8],
"y1": [1,1,0,0,1,1,0,0],
"y2": [0,0,0,0,1,1,1,1]})
安装:pip install scikit-multilearn
文档:http://scikit.ml/stratification.html
代码:
from skmultilearn.model_selection import iterative_train_test_split
X_train, y_train, X_test, y_test = iterative_train_test_split(X, y, test_size = 0.5)
X_train = array([[1],[4],[6],[7]])
y_train = array([[1, 0],[0, 0],[1, 1],[0, 1]])
X_test = array([[2],[3],[5],[8]])
y_test = array([[1, 0],[0, 0],[1, 1],[0, 1]]))