In PyTorch, a DataLoader splits a dataset into fixed-size batches, with extra options such as shuffling, and you can then iterate over it.
But what is the best way to increase the batch size as training goes on, e.g. the first 10 batches with batch size 50, the next 5 batches with batch size 100, and so on?
I tried splitting the tensor and then concatenating the pieces:
# 10x50 + 5x100
import torch

originalTensor = torch.randn(1000, 80)
split1 = torch.split(originalTensor, 500, dim=0)
split2 = torch.split(list(split1)[0], 100, dim=0)
Is there a way to then pass the concatenated tensors into a DataLoader, or any other way to turn them directly into a generator (possibly losing shuffling and other features)?
Answer 0 (score: 1)
I think you can achieve this by providing a non-default batch_sampler to your DataLoader. For example:
from torch.utils.data import Sampler


class VaryingSizeBatchSampler(Sampler):
    r"""Wraps another sampler to yield a varying-size mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size_fn (function): Maps the batch counter to the size of that mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than the current batch size.
    """

    def __init__(self, sampler, batch_size_fn, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        self.sampler = sampler
        self.batch_size_fn = batch_size_fn
        self.drop_last = drop_last
        self.batch_counter = 0  # not reset between epochs, so the schedule continues across epochs

    def __iter__(self):
        batch = []
        cur_batch_size = self.batch_size_fn(self.batch_counter)  # get current batch size
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == cur_batch_size:
                yield batch
                self.batch_counter += 1
                cur_batch_size = self.batch_size_fn(self.batch_counter)  # get current batch size
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        raise NotImplementedError('You need to implement it yourself!')
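A minimal usage sketch (the schedule in batch_size_fn and the toy TensorDataset below are my own illustration, not part of the original answer): pair the batch sampler with a RandomSampler so shuffling is kept, and pass it via the batch_sampler argument; when batch_sampler is given, batch_size, shuffle, sampler and drop_last must not also be passed to the DataLoader.

import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler

# Hypothetical schedule: first 10 batches of size 50, then batches of size 100.
def batch_size_fn(batch_idx):
    return 50 if batch_idx < 10 else 100

dataset = TensorDataset(torch.randn(1000, 80))
batch_sampler = VaryingSizeBatchSampler(RandomSampler(dataset), batch_size_fn, drop_last=False)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for (x,) in loader:
    print(x.shape)  # 10 batches of [50, 80] followed by 5 batches of [100, 80]

Note that len(loader) would raise here, since __len__ is left unimplemented in the sampler above; iterating over the loader works fine.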