用pytorch沿着任意维度映射功能?麻木?

时间:2018-10-12 14:53:57

标签: python numpy matrix optimization pytorch

我想知道为什么找不到实用程序来沿着复杂的张量/数组/矩阵的任何维度映射自定义pytorchnumpy转换。
我想我记得R中提供了这种功能。有了这个幻想tch.map实用程序,您可以做到:

>>> import torch as tch # or numpy

>>> # one torch tensor
>>> a = tch.tensor([0, 1, 2, 3, 4])
>>> # one torch function (dummy) returning 2 values
>>> f = lambda x: tch.tensor((x + 1, x * 2))
>>> # map f along dimension 0 of a, expecting 2 outputs
>>> res = tch.map(f, a, 0, 2) # fantasy, optimized on CPU/GPU..
>>> res
tensor([[1, 0],
        [2, 2],
        [3, 4],
        [4, 6],
        [5, 8]])
>>> res.shape
torch.Size([5, 2])

>>> # another tensor
>>> a = tch.tensor(list(range(24))).reshape(2, 3, 4).type(tch.double)
>>> # another function (dummy) returning 2 values
>>> f = lambda x: tch.tensor((tch.mean(x), tch.std(x)))
>>> # map f along dimension 2 of a, expecting 2 outputs
>>> res = tch.map(f, a, 2, 2) # fantasy, optimized on CPU/GPU..
tensor([[[ 1.5000,  1.2910],
         [ 5.5000,  1.2910],
         [ 9.5000,  1.2910]],

        [[13.5000,  1.2910],
         [17.5000,  1.2910],
         [21.5000,  1.2910]]])
>>> res.shape
torch.Size([2, 3, 2])

>>> # yet another tensor
>>> a = tch.tensor(list(range(12))).reshape(3, 4)
>>> # another function (dummy) returning 2x2 values
>>> f = lambda x: x + tch.rand(2, 2)
>>> # map f along all values of a, expecting 2x2 outputs
>>> res = tch.map(f, a, -1, (2, 2)) # fantasy, optimized on CPU/GPU..
>>> print(res)
tensor([[[[ 0.4827,  0.3043],
          [ 0.8619,  0.0505]],

         [[ 1.4670,  1.5715],
          [ 1.1270,  1.7752]],

         [[ 2.9364,  2.0268],
          [ 2.2420,  2.1239]],

         [[ 3.9343,  3.6059],
          [ 3.3736,  3.5178]]],


        [[[ 4.2063,  4.9981],
          [ 4.3817,  4.4109]],

         [[ 5.3864,  5.3826],
          [ 5.3614,  5.1666]],

         [[ 6.6926,  6.2469],
          [ 6.7888,  6.6803]],

         [[ 7.2493,  7.5727],
          [ 7.6129,  7.1039]]],


        [[[ 8.3171,  8.9037],
          [ 8.0520,  8.9587]],

         [[ 9.5006,  9.1297],
          [ 9.2620,  9.8371]],

         [[10.4955, 10.5853],
          [10.9939, 10.0271]],

         [[11.3905, 11.9326],
          [11.9376, 11.6408]]]])
>>> res.shape
torch.Size([3, 4, 2, 2])

相反,我不断发现自己被复杂的tch.stacktch.squeezetch.reshapetch.permute所困扰,数着我的手指不会迷路。

是否存在这样的实用程序,并且由于某种原因我错过了它?
这样的实用程序由于某种原因无法实现吗?

0 个答案:

没有答案