我试图创建一个用于元学习的数据加载器,但发现我的代码非常慢,我不知道为什么。之所以这样做,是因为元学习中使用的是一组数据集(因此我需要它们的数据加载器)。
我想知道是否是因为我有一个整理函数生成数据加载程序。
这是整理函数,可生成数据加载器(并接收所有数据集):
class GetMetaBatch_NK_WayClassTask:
def __init__(self, meta_batch_size, n_classes, k_shot, k_eval, shuffle=True, pin_memory=True, original=False, flatten=True):
self.meta_batch_size = meta_batch_size
self.n_classes = n_classes
self.k_shot = k_shot
self.k_eval = k_eval
self.shuffle = shuffle
self.pin_memory = pin_memory
self.original = original
self.flatten = flatten
def __call__(self, all_datasets, verbose=False):
NUM_WORKERS = 0 # no need to change
get_data_loader = lambda data_set: iter(data.DataLoader(data_set, batch_size=self.k_shot+self.k_eval, shuffle=self.shuffle, num_workers=NUM_WORKERS, pin_memory=self.pin_memory))
#assert( len(meta_set) == self.meta_batch_size*self.n_classes )
# generate M N,K-way classification tasks
batch_spt_x, batch_spt_y, batch_qry_x, batch_qry_y = [], [], [], []
for m in range(self.meta_batch_size):
n_indices = random.sample(range(0,len(all_datasets)), self.n_classes)
# create N-way, K-shot task instance
spt_x, spt_y, qry_x, qry_y = [], [], [], []
for i,n in enumerate(n_indices):
data_set_n = all_datasets[n]
dataset_loader_n = get_data_loader(data_set_n) # get data set for class n
data_x_n, data_y_n = next(dataset_loader_n) # get all data from current class
spt_x_n, qry_x_n = data_x_n[:self.k_shot], data_x_n[self.k_shot:] # [K, CHW], [K_eval, CHW]
# get labels
if self.original:
#spt_y_n = torch.tensor([n]).repeat(self.k_shot)
#qry_y_n = torch.tensor([n]).repeat(self.k_eval)
spt_y_n, qry_y_n = data_y_n[:self.k_shot], data_y_n[self.k_shot:]
else:
spt_y_n = torch.tensor([i]).repeat(self.k_shot)
qry_y_n = torch.tensor([i]).repeat(self.k_eval)
# form K-shot task for current label n
spt_x.append(spt_x_n); spt_y.append(spt_y_n) # array length N with tensors size [K, CHW]
qry_x.append(qry_x_n); qry_y.append(qry_y_n) # array length N with tensors size [K, CHW]
# form N-way, K-shot task with tensor size [N,W, CHW]
spt_x, spt_y, qry_x, qry_y = torch.stack(spt_x), torch.stack(spt_y), torch.stack(qry_x), torch.stack(qry_y)
# form N-way, K-shot task with tensor size [N*W, CHW]
if verbose:
print(f'spt_x.size() = {spt_x.size()}')
print(f'spt_y.size() = {spt_y.size()}')
print(f'qry_x.size() = {qry_x.size()}')
print(f'spt_y.size() = {qry_y.size()}')
print()
if self.flatten:
CHW = qry_x.shape[-3:]
spt_x, spt_y, qry_x, qry_y = spt_x.reshape(-1, *CHW), spt_y.reshape(-1), qry_x.reshape(-1, *CHW), qry_y.reshape(-1)
## append to N-way, K-shot task to meta-batch of tasks
batch_spt_x.append(spt_x); batch_spt_y.append(spt_y)
batch_qry_x.append(qry_x); batch_qry_y.append(qry_y)
## get a meta-set of M N-way, K-way classification tasks [M,K*N,C,H,W]
batch_spt_x, batch_spt_y, batch_qry_x, batch_qry_y = torch.stack(batch_spt_x), torch.stack(batch_spt_y), torch.stack(batch_qry_x), torch.stack(batch_qry_y)
return batch_spt_x, batch_spt_y, batch_qry_x, batch_qry_y
传递到此处的另一个数据加载器:
def get_meta_set_loader(meta_set, meta_batch_size, n_episodes, n_classes, k_shot, k_eval, pin_mem=True, n_workers=4):
"""[summary]
Args:
meta_set ([type]): the meta-set
meta_batch_size ([type]): [description]
n_classes ([type]): [description]
pin_mem (bool, optional): [Since returning cuda tensors in dataloaders is not recommended due to cuda subties with multithreading, instead set pin=True for fast transfering of the data to cuda]. Defaults to True.
n_workers (int, optional): [description]. Defaults to 4.
Returns:
[type]: [description]
"""
if n_classes > len(meta_set):
raise ValueError(f'You really want a N larger than the # classes in the meta-set? n_classes, len(meta_set = {n_classes, len(meta_set)}')
collator_nk_way = GetMetaBatch_NK_WayClassTask(meta_batch_size, n_classes, k_shot, k_eval)
episodic_sampler = EpisodicSampler(total_classes=len(meta_set), n_episodes=n_episodes)
episodic_metaloader = data.DataLoader(
meta_set,
num_workers=n_workers,
pin_memory=pin_mem, # to make moving to cuda more efficient
collate_fn=collator_nk_way, # does the collecting to return M N,K-shot task
batch_sampler=episodic_sampler # for keeping track of the episode
)
return episodic_metaloader
(将生成一个较小的示例)
相关:
答案 0 :(得分:1)
从概念上讲,即使一个在另一个内部,pytorch数据加载器也应该没有问题。调试问题的一种方法是使用line_profiler
软件包来更好地了解减速的发生位置。
如果使用line_profiler
后仍无法解决问题,请使用事件探查器的输出更新您的问题,以帮助我们了解可能出了问题。允许探查器运行一段时间,以收集有关数据加载器执行情况的足够统计信息。 @profile
装饰器也适用于函数和类函数,因此应适用于数据加载器函数。