PyTorch,根据数据列中的标签选择批次

时间:2021-02-16 16:51:54

标签: python pytorch pytorch-dataloader

我有一个这样的数据集:

<头>
索引 标签 功能1 功能2 目标
1 标签1 1.4342 88.4554 0.5365
2 标签1 2.5656 54.5466 0.1263
3 标签2 5.4561 845.556 0.8613
4 标签3 6.5546 8.52545 0.7864
5 标签3 8.4566 945.456 0.4646

每个标签中的条目数量并不总是相同的。

我的目标是仅加载具有特定标签或标签的数据,以便我只获取一个小批量的 tag1 中的条目,然后获取另一个小批量的 tag2 中的条目,如果我设置了 batch_size=1。或者例如 tag1tag2 如果我设置了 batch_size=2

到目前为止,我的代码完全忽略了 tag 标签,只是随机选择批次。

我构建了这样的数据集:

# features is a matrix with all the features columns through all rows
# target is a vector with the target column through all rows
featuresTrain, targetTrain = projutils.get_data(train=True, config=config)
train = torch.utils.data.TensorDataset(featuresTrain, targetTrain)
train_loader = make_loader(train, batch_size=config.batch_size)

我的加载器(通常)如下所示:

def make_loader(dataset, batch_size):
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=batch_size, 
                                     shuffle=True,
                                     pin_memory=True,
                                     num_workers=8)
return loader

然后我像这样训练:

for epoch in range(config.epochs):
    for _, (features, target) in enumerate(loader):
        loss = train_batch(features, target, model, optimizer, criterion)

还有train_batch

def train_batch(features, target, model, optimizer, criterion):
features, target = features.to(device), target.to(device)

# Forward pass ➡
outputs = model(features)
loss = criterion(outputs, target
return loss

1 个答案:

答案 0 :(得分:2)

一个简单的数据集,它大致实现了我所知道的你正在寻找的特征。

class CustomDataset(data.Dataset):
    def __init__(self,featuresTrain,targetsTrain,tagsTrain,sample_equally = False):
       # self.tags should be a tensor in k-hot encoding form so a 2D tensor, 
       self.tags = tagsTrain
       self.x = featuresTrain
       self.y = targetsTrain
       self.unique_tagsets = None
       self.sample_equally = sample_equally

       # self.active tags is a 1D k-hot encoding vector
       self.active_tags = self.get_random_tag_set()
       
    
    def get_random_tag_set(self):
        # gets all unique sets of tags and returns one randomly
        if self.unique_tagsets is None:
             self.unique_tagsets = self.tags.unique(dim = 0)
        if self.sample_equally:
             rand_idx = torch.randint(len(self.unique_tagsets),[1])[1].detatch().int()
             return self.unique_tagsets[rand_idx]
        else:
            rand_idx = torch.randint(len(self.tags),[1])[1].detatch().int()
            return self.tags[rand_idx]

    def set_tags(self,tags):
       # specifies the set of tags that must be present for a datum to be selected
        self.active_tags = tags

    def __getitem__(self,index):
        # get all indices of elements with self.active_tags
        indices = torch.where(self.tags == self.active_tags)[0]

        # we select an index based on the indices of the elements that have the tag set
        idx = indices[index % len(indices)]

        item = self.x[idx], self.y[idx]
        return item

    def __len__(self):
        return len(self.y)

该数据集随机选择一组标签。然后,每次调用 __getitem__() 时,它都会使用指定的索引从具有标签集的数据元素中进行选择。您可以在每个小批量之后调用 set_tags()get_random_tag_set() 然后调用 set_tags(),或者您想要更改标签集的频率,或者您可以自己手动指定标签集。数据集继承自 torch.data.Dataset,因此您应该可以不加修改地将 if 与 torch.data.Dataloader 一起使用。

您可以使用 sample_equally 指定是要根据每组标签的流行程度对每个标签集进行采样,还是要对所有标签集进行平均采样,而不管该集有多少元素。< /p>

简而言之,这个数据集的边缘有点粗糙,但应该允许您使用相同的标签集对所有批次进行采样。主要缺点是每个元素可能每批次被采样多次。

对于初始编码,假设开始每个数据示例都有一个标签列表,所以 tags 是一个列表列表,每个子列表包含标签。以下代码会将其转换为 k-hot 编码,因此您只需:

def to_k_hot(tags):
  all_tags = []
  for ex in tags:
    for tag in ex:
        all_tags.append(tag)
  unique_tags = list(set(all_tags)) # remove duplicates

  tagsTrain = torch.zeros([len(tags),len(unique_tags)]): 
  for i in range(len(tags)): # index through all examples
    for j in range(len(unique_tags)): # index through all unique_tags
        if unique_tags[j] in tags[i]:
             tagsTrain[i,j] = 1

  return tagsTrain

举个例子,假设你有一个数据集的以下标签:

tags = [ [tag1],
         [tag1,tag2],
         [tag3],
         [tag2],
         [],
         [tag1,tag2,tag3] ]

调用 to_k_hot(tags) 会返回:

tensor([1,0,0],
       [1,1,0],
       [0,0,1],
       [0,1,0],
       [0,0,0],
       [1,1,1]])
相关问题