如何在PyTorch中平衡(过采样)不平衡数据(使用WeightedRandomSampler)?

时间:2019-01-29 06:49:58

标签: python machine-learning pytorch data-cleaning

我有2类问题,并且我的数据高度不平衡。我有来自一堂课的232550个样本和来自第二堂课的13498。 PyTorch文档和互联网告诉我为我的DataLoader使用类WeightedRandomSampler。

我曾经尝试过使用WeightedRandomSampler,但是我一直遇到错误。

    trainratio = np.bincount(trainset.labels)
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)
    train_sampleweights = train_weights[trainset.labels]
    train_sampler = WeightedRandomSampler(weights=train_sampleweights, 
    num_samples = len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler, 
    shuffle=False)

我看不到为什么在初始化WeightedRandomSampler类时收到此错误?

我尝试了其他类似的解决方法,但到目前为止,所有尝试均会产生一些错误。 我应该如何实现这一点来平衡训练,验证和测试数据?

当前出现此错误:

  

train__sampleweights = train_weights [trainset.labels] ValueError:也是   许多维度“ str”

1 个答案:

答案 0 :(得分:0)

问题出在trainset.labels的类型中 要解决该错误,可以将trainset.labels转换为float