我有一个批处理数据列表,每个样品带有多个标签。那么如何以一键编码的方式将其隐藏到torch.Tensor中呢?
例如,使用batch_size=5
和class_num=6
,
label =[
[1,2,3],
[4,6],
[1],
[1,4,5],
[4]
]
如何在pytorch中使其成为一键编码?
label_tensor=tensor([
[1,1,1,0,0,0],
[0,0,0,1,0,1],
[1,0,0,0,0,0],
[1,0,0,1,1,0],
[0,0,0,1,0,0]
])
答案 0 :(得分:0)
如果批量大小可以从len(labels)
中得出:
def to_onehot(labels, n_categories, dtype=torch.float32):
batch_size = len(labels)
one_hot_labels = torch.zeros(size=(batch_size, n_categories), dtype=dtype)
for i, label in enumerate(labels):
# Subtract 1 from each LongTensor because your
# indexing starts at 1 and tensor indexing starts at 0
label = torch.LongTensor(label) - 1
one_hot_labels[i] = one_hot_labels[i].scatter_(dim=0, index=label, value=1.)
return one_hot_labels
,您有6个类别,希望输出为整数张量:
to_onehot(labels, n_categories=6, dtype=torch.int64)
tensor([[1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 1],
[1, 0, 0, 0, 0, 0],
[1, 0, 0, 1, 1, 0],
[0, 0, 0, 1, 0, 0]])
如果您以后要使用标签平滑,混合或类似的方法,我会坚持使用torch.float32
。