PyTorch:TypeError:'int' 对象不可下标

时间:2021-05-23 10:51:55

标签: python pytorch

我是 pytorch 的新手,我正在为多层感知器创建一个热编码函数,但我遇到了一些问题。代码如下:

def one_hot_encoding(label):
    for idx, val in enumerate(label):
        one_hot_outputs = [0]*len(label)
        idx_n = idx[val]
        one_hot_outputs[idx_n] = 1
return one_hot_outputs

我有一个类型错误说:

in one_hot_encoding(label)
  2     for idx, val in enumerate(label):
  3         one_hot_outputs = [0]*len(label)
> 4         idx_n = idx[val]
  5         one_hot_outputs[idx_n] = 1
  6     return one_hot_outputs
TypeError: 'int' object is not subscriptable

有什么帮助吗?

1 个答案:

答案 0 :(得分:0)

这不是一个直接的答案,而是一个替代方案。 PyTorch 已经具有以下功能:torch.nn.functional.one_hot。因此,如果您有一个标签张量 labeln 类,只需调用:

torch.nn.functional.one_hot(label, num_classes=n)