使用Pytorch进行分层交叉验证

时间:2018-11-17 20:35:34

标签: pytorch cross-validation

我的目标是使用神经网络进行二进制分类。 问题是数据集不平衡,我有1类的90%和0类的10个。 为了解决这个问题,我想使用分层交叉验证。

我正在与Pytorch一起工作的问题,我找不到任何示例,文档也没有提供它,而且我是学生,对于神经网络来说还很新。

有人可以帮忙吗? 谢谢!

2 个答案:

答案 0 :(得分:1)

我发现最简单的方法是在将数据传递到Pytorch DatasetDataLoader之前对分层进行分层。这样一来,您就不必将所有代码移植到skorch上,而这会破坏与某些集群计算框架的兼容性。

答案 1 :(得分:0)

看看skorch。这是一个scikit-learn兼容的神经网络库,其中包装了PyTorch。它具有用于交叉验证的功能CVSplit,也可以使用sklearn。 从文档中:

net = NeuralNetClassifier(
   module=MyModule,
   train_split=None,
)
from sklearn.model_selection import cross_val_predict
y_pred = cross_val_predict(net, X, y, cv=5)