如何使用WeightedRandomSampler平衡PyTorch中的不平衡数据?

时间:2019-01-29 08:32:51

标签: python machine-learning pytorch

我有2类问题,并且我的数据不平衡。 0类有232550个样本,1类有13498个样本。 PyTorch文档和互联网告诉我为我的DataLoader使用类WeightedRandomSampler。

我曾经尝试过使用WeightedRandomSampler,但是我一直遇到错误。

    trainratio = np.bincount(trainset.labels) #trainset.labels is a list of 
    float [0,1,0,0,0,...] 
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)
    train_sampleweights = train_weights[trainset.labels]
    train_sampler = WeightedRandomSampler(weights=train_sampleweights, 
                                 num_samples=len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler, 
                                       shuffle=False)

我打印出的某些尺寸:

train_weights = tensor([4.3002e-06, 4.3002e-06, 4.3002e-06,  ..., 
4.3002e-06, 4.3002e-06, 4.3002e-06])

train_weights shape=  torch.Size([246048])

我看不到为什么出现此错误:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.weights = torch.tensor(weights, dtype=torch.double)

我尝试了其他类似的解决方法,但到目前为止,所有尝试均会产生一些错误。 我应该如何实现这一点来平衡训练,验证和测试数据?

1 个答案:

答案 0 :(得分:0)

因此,这显然是内部警告,而不是错误。根据PyTorch的家伙说,我可以继续编码,而不必担心警告消息。