我正在尝试在PyTorch中训练一个深度学习模型的图像,这些图像已经被特定尺寸所取代。我想用迷你批次训练我的模型,但是迷你批量大小并没有巧妙地划分每个桶中的例子数量。
我在a previous post中看到的一个解决方案是用额外的空格填充图像(无论是在运行中还是在训练开始时一次性完成),但我不想这样做。相反,我希望在培训期间允许批量大小灵活。
具体来说,如果N
是存储桶中的图像数量且B
是批量大小,那么对于该存储桶我希望N // B
批次B
除此之外,除了N
和N // B + 1
批次。最后一批可以有少于B
个例子。
作为一个例子,假设我有索引[0,1,...,19],包括在内,我想使用批量大小为3.
索引[0,9]对应于桶0中的图像(形状(C,W1,H1))
索引[10,19]对应于桶1中的图像(形状(C,W2,H2))
(所有图像的通道深度相同)。然后,可接受的索引分区将是
batches = [
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[9],
[10, 11, 12],
[13, 14, 15],
[16, 17, 18],
[19]
]
我更愿意分别处理分别为9和19的图像,因为它们具有不同的尺寸。
通过PyTorch的文档,我找到了生成小批量索引列表的BatchSampler
类。我创建了一个自定义Sampler
类,它模拟上述索引的分区。如果它有帮助,这是我的实现:
class CustomSampler(Sampler):
def __init__(self, dataset, batch_size):
self.batch_size = batch_size
self.buckets = self._get_buckets(dataset)
self.num_examples = len(dataset)
def __iter__(self):
batch = []
# Process buckets in random order
dims = random.sample(list(self.buckets), len(self.buckets))
for dim in dims:
# Process images in buckets in random order
bucket = self.buckets[dim]
bucket = random.sample(bucket, len(bucket))
for idx in bucket:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
# Yield half-full batch before moving to next bucket
if len(batch) > 0:
yield batch
batch = []
def __len__(self):
return self.num_examples
def _get_buckets(self, dataset):
buckets = defaultdict(list)
for i in range(len(dataset)):
img, _ = dataset[i]
dims = img.shape
buckets[dims].append(i)
return buckets
但是,当我使用自定义Sampler
类时,我会生成以下错误:
Traceback (most recent call last):
File "sampler.py", line 143, in <module>
for i, batch in enumerate(dataloader):
File "/home/roflcakzorz/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 263, in __next__
indices = next(self.sample_iter) # may raise StopIteration
File "/home/roflcakzorz/anaconda3/lib/python3.6/site-packages/torch/utils/data/sampler.py", line 139, in __iter__
batch.append(int(idx))
TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'
DataLoader
类似乎期望传递索引,而不是索引列表。
我是否应该为此任务使用自定义Sampler
课程?我还考虑将自定义collate_fn
传递给DataLoader
,但是使用这种方法我不相信我可以控制允许哪些索引在同一个小批量中。任何指导都将不胜感激。
答案 0 :(得分:0)
每个样本是否有2个网络(必须修复cnn内核大小)。如果是,只需将上面的custom_sampler
传递给DataLoader类的batch_sampler args即可。这样可以解决问题。
答案 1 :(得分:0)
您好,因为每个批次都应包含相同尺寸的图像,所以您的CustomSampler
可以正常工作,因此需要将其作为关键字mx.gluon.data.DataLoader
传递给batch_sampler
。但是,如文档中所述,请记住以下几点:
“如果指定了
shuffle
,请不要指定sampler
,last_batch
和batch_sampler
”