How can I use PyTorch to convert a list of labels into a one-hot encoding in a multi-label classification context?

Time: 2019-05-14 05:15:06

Tags: pytorch one-hot-encoding multilabel-classification

I have a batch of data in which each sample has multiple labels. How can I convert it into a torch.Tensor in one-hot encoded form? For example, with batch_size=5 and class_num=6:

label =[
[1,2,3],
[4,6],
[1],
[1,4,5],
[4]
]

How can I turn this into a one-hot encoding in PyTorch?

label_tensor=tensor([
[1,1,1,0,0,0],
[0,0,0,1,0,1],
[1,0,0,0,0,0],
[1,0,0,1,1,0],
[0,0,0,1,0,0]
])

1 Answer:

Answer 0 (score: 0)

If the batch size can be derived from len(labels):

import torch


def to_onehot(labels, n_categories, dtype=torch.float32):
    batch_size = len(labels)
    one_hot_labels = torch.zeros(size=(batch_size, n_categories), dtype=dtype)
    for i, label in enumerate(labels):
        # Subtract 1 from each LongTensor because your
        # indexing starts at 1 and tensor indexing starts at 0
        label = torch.LongTensor(label) - 1
        one_hot_labels[i] = one_hot_labels[i].scatter_(dim=0, index=label, value=1.)
    return one_hot_labels

In your case, you have 6 categories and want the output to be an integer tensor:

to_onehot(labels, n_categories=6, dtype=torch.int64)
tensor([[1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 1],
        [1, 0, 0, 0, 0, 0],
        [1, 0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0, 0]])
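
If the Python loop over samples becomes a bottleneck, the same result can probably be built with a single indexed assignment. The following is only a sketch under the same assumption as above (labels are 1-based); `to_onehot_vectorized` is a hypothetical name and not part of the function shown earlier.

```python
import torch


def to_onehot_vectorized(labels, n_categories, dtype=torch.float32):
    # Hypothetical helper: builds the (row, column) index pairs once
    # instead of calling scatter_ per sample.
    # Row index: sample i is repeated once for each of its labels.
    rows = torch.tensor(
        [i for i, sample in enumerate(labels) for _ in sample],
        dtype=torch.long,
    )
    # Column index: the label itself, shifted from 1-based to 0-based.
    cols = torch.tensor(
        [c - 1 for sample in labels for c in sample],
        dtype=torch.long,
    )
    one_hot_labels = torch.zeros(len(labels), n_categories, dtype=dtype)
    one_hot_labels[rows, cols] = 1
    return one_hot_labels


labels = [[1, 2, 3], [4, 6], [1], [1, 4, 5], [4]]
result = to_onehot_vectorized(labels, n_categories=6, dtype=torch.int64)
```

The list comprehensions still run in Python, so the gain is mostly in avoiding one tensor operation per sample.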

If you later want to use label smoothing, mixup, or something similar, I would stick with torch.float32.
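
To illustrate why a float target helps, here is a minimal sketch of one possible label-smoothing scheme for multi-label targets. The smoothing strength `eps`, the choice of pulling targets toward 0.5, and the all-zero `logits` are assumptions made purely for illustration.

```python
import torch
import torch.nn.functional as F

# Float targets, matching the first two rows of the example output.
targets = torch.tensor(
    [[1, 1, 1, 0, 0, 0],
     [0, 0, 0, 1, 0, 1]],
    dtype=torch.float32,
)

eps = 0.1  # assumed smoothing strength, chosen only for illustration

# Pull every entry slightly toward 0.5: 1 becomes 0.95 and 0 becomes 0.05.
# Fractional values like these cannot be stored in an int64 tensor.
smoothed = targets * (1.0 - eps) + 0.5 * eps

# Placeholder model outputs; a real model would produce these.
logits = torch.zeros_like(targets)

# This loss expects floating-point targets.
loss = F.binary_cross_entropy_with_logits(logits, smoothed)
```

With an integer tensor, the smoothing step would first need a cast to a floating-point dtype, which is why starting from torch.float32 is the simpler choice here.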