我有一个大小为4 x 6
的张量,其中4是批量大小,6是序列长度。序列向量的每个元素都是一些索引(0到n)。我想创建一个4 x 6 x n
张量,其中第三维中的向量将是索引的一个热编码,这意味着我想将1放在指定的索引中,其余的值将为零。
例如,我有以下张量:
[[5, 3, 2, 11, 15, 15],
[1, 4, 6, 7, 3, 3],
[2, 4, 7, 8, 9, 10],
[11, 12, 15, 2, 5, 7]]
这里,所有值都介于(0到n)之间,其中n = 15.因此,我想将张量转换为4 X 6 X 16
张量,其中第三维将代表一个热编码向量。
如何使用PyTorch功能实现这一目标?现在,我正在使用循环,但我想避免循环!
答案 0 :(得分:5)
新答案
从PyTorch 1.1开始,one_hot
中有一个torch.nn.functional
函数。给定索引indices
和最大索引n
的任何张量,您可以创建one_hot版本,如下所示:
n = 5
indices = torch.randint(0,n, size=(4,7))
one_hot = torch.nn.functional.one_hot(indices, n) # size=(4,7,n)
很老的回答
目前,根据我的经验,切片和索引在PyTorch中可能会有点痛苦。我假设你不想将你的张量转换为numpy数组。我现在能想到的最优雅的方法是使用稀疏张量然后转换为密集张量。这将如下工作:
from torch.sparse import FloatTensor as STensor
batch_size = 4
seq_length = 6
feat_dim = 16
batch_idx = torch.LongTensor([i for i in range(batch_size) for s in range(seq_length)])
seq_idx = torch.LongTensor(list(range(seq_length))*batch_size)
feat_idx = torch.LongTensor([[5, 3, 2, 11, 15, 15], [1, 4, 6, 7, 3, 3],
[2, 4, 7, 8, 9, 10], [11, 12, 15, 2, 5, 7]]).view(24,)
my_stack = torch.stack([batch_idx, seq_idx, feat_idx]) # indices must be nDim * nEntries
my_final_array = STensor(my_stack, torch.ones(batch_size * seq_length),
torch.Size([batch_size, seq_length, feat_dim])).to_dense()
print(my_final_array)
注意:PyTorch目前正在进行一些工作,将在接下来的两到三周内添加numpy风格的广播和其他功能以及其他功能。所以有可能,在不久的将来会有更好的解决方案。
希望这会对你有所帮助。
答案 1 :(得分:3)
可以PyTorch
使用就地scatter_
方法为任何Tensor
对象完成此操作。
labels = torch.LongTensor([[[2,1,0]], [[0,1,0]]]).permute(0,2,1) # Let this be your current batch
batch_size, k, _ = labels.size()
labels_one_hot = torch.FloatTensor(batch_size, k, num_classes).zero_()
labels_one_hot.scatter_(2, labels, 1)
对于num_classes=3
(指数应与[0,3)
不同),这会给你
(0 ,.,.) =
0 0 1
0 1 0
1 0 0
(1 ,.,.) =
1 0 0
0 1 0
1 0 0
[torch.FloatTensor of size 2x3x3]
请注意,labels
应为torch.LongTensor
。
PyTorch文档参考:torch.Tensor.scatter_
答案 2 :(得分:1)
我找到的最简单的方法。其中x是数字列表,而class_count是您拥有的类的数量。
def one_hot(x, class_count):
return torch.eye(class_count)[x,:]
像这样使用它:
x = [0,2,5,4]
class_count = 8
one_hot(x,class_count)
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0.]])