将张量转换为索引的一个热编码张量

时间:2019-06-09 09:46:14

标签: pytorch one-hot-encoding

我有形状为(1,1,128,128,128)的标签张量,其中值的范围可能为0.24。我想使用nn.fucntional.one_hot函数将其转换为一个热编码张量

n = 24
one_hot = torch.nn.functional.one_hot(indices, n)

但是,坦白地说,这需要一个指数张量,我不确定如何获得这些指数。我唯一拥有的张量是上述形状的标签张量,它包含1-24范围内的值,而不是索引

如何从张量中获取索引张量?预先感谢。

2 个答案:

答案 0 :(得分:0)

如果您遇到的错误是此错误:

Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
RuntimeError: one_hot is only applicable to index tensor.

也许您只需要转换为int64

import torch

# random Tensor with the shape you said
indices = torch.Tensor(1, 1, 128, 128, 128).random_(1, 24)
# indices.shape => torch.Size([1, 1, 128, 128, 128])
# indices.dtype => torch.float32

n = 24
one_hot = torch.nn.functional.one_hot(indices.to(torch.int64), n)
# one_hot.shape => torch.Size([1, 1, 128, 128, 128, 24])
# one_hot.dtype => torch.int64

您也可以使用indices.long()

答案 1 :(得分:0)

如果您的标签存储在列表或numpy数组中,则torch.as_tensor函数也可能会有所帮助:

import torch
import random

n_classes = 5
n_samples = 10

# Create list n_samples random labels (can also be numpy array)
labels = [random.randrange(n_classes) for _ in range(n_samples)]
# Convert to torch Tensor
labels_tensor = torch.as_tensor(labels)
# Create one-hot encodings of labels
one_hot = torch.nn.functional.one_hot(labels_tensor, num_classes=n_classes)
print(one_hot)

输出one_hot的形状为(n_samples, n_classes),其外观应类似于:

tensor([[0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [1, 0, 0, 0, 0]])