PyTorch,逐元素应用不同的功能

时间:2019-10-24 02:03:37

标签: python pytorch

我定义了这样的张量

t_shape = [4, 1]
data = torch.rand(t_shape)

我想对每行应用不同的功能。

funcs = [lambda x: x+1, lambda x: x**2, lambda x: x-1, lambda x: x*2]  # each function for each row.

我可以使用以下代码完成

d = torch.tensor([f(data[i]) for i, f in enumerate(funcs)])

如何使用PyTorch中定义的更高级的API以正确的方式完成操作?

1 个答案:

答案 0 :(得分:1)

我认为您的解决方案很好。但这不适用于任何张量形状。您可以按如下所示稍微修改解决方案。

t_shape = [4, 10, 10]
data = torch.rand(t_shape)

funcs = [lambda x: x+1, lambda x: x**2, lambda x: x-1, lambda x: x*2]

# only change the following 2 lines
d = [f(data[i]) for i, f in enumerate(funcs)]
d = torch.stack(d, dim=0)