Using torch.utils TensorDataset and DataLoader

Time: 2019-12-16 17:34:50

Tags: python tensorflow torch dataloader

I'm trying to use data_utils.TensorDataset and data_utils.DataLoader. My raw data xx is a pandas DataFrame. I use get_pinned_object(xx) because xx is a large dataset.

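import torch.utils.data as data_utils
from sklearn.model_selection import cross_validate
# SKPipeline, SGD, pipe_pre, pipe_model, get_pinned_object and xx
# are defined elsewhere and not shown in the question.

def train_model(space, reporter):
    model = SKPipeline([('pipe_pre', pipe_pre), ('pipe_model', pipe_model)])
    optimizer = SGD()
    dataset = get_pinned_object(xx)  # xx is a large pandas DataFrame

    # Features: every column except 'target'; labels: the 'target' column.
    # This is the call that fails with the assertion quoted below.
    dataset = data_utils.TensorDataset(dataset.drop(['target'], axis=1), dataset.target)
    dataloader = data_utils.DataLoader(dataset, batch_size=50, shuffle=True, num_workers=2)
    for i, (data, target) in enumerate(dataloader):
        accuracy = cross_validate(model, data, target, scoring='roc_auc', cv=5,
                                  return_train_score=False, n_jobs=-1)
        reporter(mean_accuracy=accuracy)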

If I try to run this, I get an error at assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors). dataset.drop(['target'],axis=1).shape is (10000, 374), while dataset.target.shape is (10000,). Aren't these the right shapes? What am I doing wrong?
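The shapes themselves are compatible: TensorDataset only requires every argument to have the same length along dimension 0, and both have 10000 rows. A likely cause of the failure is the argument types. The quoted assertion calls .size(0) on each argument, which works for torch.Tensor, but on a pandas DataFrame or Series .size is an integer attribute, not a method. The sketch below converts the pandas objects to tensors first. It is a minimal illustration under stated assumptions: a synthetic DataFrame with hypothetical column names f0 ... f373 stands in for xx, float32 features and int64 labels are assumed, and num_workers is left at its default to keep the example self-contained.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical stand-in for the question's DataFrame xx:
# 10000 rows, 374 feature columns (assumed names f0..f373) plus a binary 'target'.
rng = np.random.default_rng(0)
xx = pd.DataFrame(rng.standard_normal((10000, 374)).astype(np.float32),
                  columns=[f"f{i}" for i in range(374)])
xx["target"] = rng.integers(0, 2, size=10000)

# Convert pandas objects to tensors: TensorDataset calls .size(0) on each
# argument and indexes it along dimension 0, which pandas objects do not support.
features = torch.tensor(xx.drop(columns=["target"]).to_numpy(), dtype=torch.float32)
labels = torch.tensor(xx["target"].to_numpy(), dtype=torch.long)

dataset = TensorDataset(features, labels)   # (10000, 374) and (10000,): same dim 0
loader = DataLoader(dataset, batch_size=50, shuffle=True)

# Each batch is a (features, labels) pair of tensors.
data, target = next(iter(loader))
print(data.shape, target.shape)
```

Each batch from the loader then has shapes (50, 374) and (50,), and the 10000 rows split into 200 batches.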

0 answers:

No answers