我是 pytorch 的新手,我正在为多层感知器创建一个热编码函数,但我遇到了一些问题。代码如下:
def one_hot_encoding(label):
for idx, val in enumerate(label):
one_hot_outputs = [0]*len(label)
idx_n = idx[val]
one_hot_outputs[idx_n] = 1
return one_hot_outputs
我有一个类型错误说:
in one_hot_encoding(label)
2 for idx, val in enumerate(label):
3 one_hot_outputs = [0]*len(label)
> 4 idx_n = idx[val]
5 one_hot_outputs[idx_n] = 1
6 return one_hot_outputs
TypeError: 'int' object is not subscriptable
有什么帮助吗?
答案 0 :(得分:0)
这不是一个直接的答案,而是一个替代方案。 PyTorch 已经具有以下功能:torch.nn.functional.one_hot
。因此,如果您有一个标签张量 label
和 n
类,只需调用:
torch.nn.functional.one_hot(label, num_classes=n)