How to increase the batch size in PyTorch

Time: 2019-09-22 05:00:39

Tags: pytorch

In PyTorch, a DataLoader splits a dataset into fixed-size batches, with additional options such as shuffling, which you can then iterate over.
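For reference, the usual fixed-size setup looks something like this (a minimal sketch with a random toy dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1000, 80))
# every batch yielded by this loader has exactly 50 rows
loader = DataLoader(dataset, batch_size=50, shuffle=True)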

But what if I need to increase the batch size as iteration proceeds, e.g. the first 10 batches of size 50, the next 5 batches of size 100, and so on? What is the best way to do this?

I tried splitting the tensor and then concatenating the pieces:

import torch

# goal: 10 batches of 50, then 5 batches of 100 (1000 rows in total)
originalTensor = torch.randn(1000, 80)
split1 = torch.split(originalTensor, 500, dim=0)   # two halves of 500 rows each
split2 = torch.split(split1[0], 50, dim=0)         # first half  -> 10 batches of 50
split3 = torch.split(split1[1], 100, dim=0)        # second half -> 5 batches of 100

After this, is there a way to pass the resulting tensors into a DataLoader, or any other way to turn them directly into a generator (which might lose shuffling and other functionality)?
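In other words, I'm after something like the plain generator below, even though it gives up the DataLoader's shuffling and other features (varying_batches is just an illustrative helper name):

def varying_batches(tensor, sizes):
    # yield consecutive row slices of `tensor` with the given batch sizes
    start = 0
    for size in sizes:
        yield tensor[start:start + size]
        start += size

# 10 batches of 50 followed by 5 batches of 100
for batch in varying_batches(originalTensor, [50] * 10 + [100] * 5):
    print(batch.shape)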

1 Answer:

Answer 0 (score: 1)

I think you can do this by providing a non-default batch_sampler to your DataLoader.
For example:

from torch.utils.data import Sampler


class VaryingSizeBatchSampler(Sampler):
    r"""Wraps another sampler to yield varying-size mini-batches of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size_fn (function): Maps the index of a batch to its size.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than the current batch size.
    """

    def __init__(self, sampler, batch_size_fn, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        self.sampler = sampler
        self.batch_size_fn = batch_size_fn
        self.drop_last = drop_last
        self.batch_counter = 0

    def __iter__(self):
        batch = []
        cur_batch_size = self.batch_size_fn(self.batch_counter)  # size of the current batch
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == cur_batch_size:
                yield batch
                self.batch_counter += 1
                cur_batch_size = self.batch_size_fn(self.batch_counter)  # size of the next batch
                batch = []
        # note: batch_counter is not reset here, so the size schedule continues across epochs
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        raise NotImplementedError('You need to implement it yourself!')
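A usage sketch might look like the following (the toy dataset, batch_size_fn, and the 10-then-5 schedule are illustrative, not part of the answer):

import torch
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler

dataset = TensorDataset(torch.randn(1000, 80))

def batch_size_fn(i):
    # hypothetical schedule: first 10 batches of size 50, then batches of size 100
    return 50 if i < 10 else 100

sampler = SequentialSampler(dataset)   # swap in RandomSampler to keep shuffling
batch_sampler = VaryingSizeBatchSampler(sampler, batch_size_fn, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([50, 80]) ten times, then torch.Size([100, 80]) five times

Since batch_sampler is given, the DataLoader's batch_size, shuffle, sampler and drop_last arguments stay at their defaults; shuffling, if needed, comes from the base sampler instead.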