火炬计数张量

时间:2020-08-07 22:31:55

标签: python pytorch

我正在尝试创建一个Python函数,该函数在给定输入dimm的情况下生成大小为[m ** dim, dim]的形式的张量

[[1,...,1,1],
 [1,...,1,2],
 ...
 [1,...,1,m],
 [1,...,2,1],
 [1,...,2,2],
 ...
 [m,...,m,m]]

在Pytorch中最好的方法是什么?

谢谢

2 个答案:

答案 0 :(得分:2)

那是一个相当有趣的问题!首先,让我们先进行一下数学运算,如果您想跳过任意基数的整数运算,请参见下面的示例代码。

首先,让我们注意到,您在基础rm中写了每一行r = i_{d-1} m**(d-1) + i_{d-2} m**(d-2) + ... + i_0 m**0的索引,然后在该行{{1}中写了j元素}的值为r1 + i_{d-j+1}j中)。该张量基本上遍历所有以基数[0, d-1]编写的整数。

话虽这么说,代码很容易附带:只需遍历所有数字(或者在基数m中对它们进行分解),然后根据此分解构建张量,并连接所有张量。

为了获得一点点效率,下面的代码而不是一行一行地逐块(大小为m * d的块)构建张量,块的最后一列始终为m

[1,2,...,m]

其中m = d = 3:

def iter_radix_m(digits, radix):
    """
    utility aux function to iterate over integers decompositions
    digits : list of size(d-1), the d-1 first digits of your number
    radix: 
    """
    index = len(digits)-1
    while digits[index] == radix-1:
        digits[index] = 0
        index -= 1
    digits[index] += 1 

def radix_tensor(m, d):
     # there are m**(d-1) blocks of size m
     nb_blocks = m**(d-1)
     # In there will be stored all blocks until final concatenation
     blocks = []
     digits = [0]*(d-1)
     # Iteration over all blocks
     for i in range(nb_blocks):
         # A column is a tensor of ones, multiplied by the corresponding digit
         # in the m-radix decomposition of i, plus 1
         # The last column is [1,2,...m] always
         cols = [torch.tensor(digits, dtype=int) * torch.ones(m,d-1, dtype=int) + 1] + [torch.arange(1,m+1).view(m,1)]
         #Concatenate these columns to make an (m,d) block
         blocks += [torch.cat(cols, dim=1)]
         iter_radix_m(digits, m)
     # Concatenate all blocks to make an (m**d, d) tensor
     return torch.cat(blocks, dim=0)

答案 1 :(得分:1)

我已经使用以下代码解决了这个问题:

import torch
import numpy as np

def mat_gen(dim, m): 
    return torch.from_numpy(np.array(np.meshgrid(*[np.arange(1, m + 1, 1) for i in range(dim)])).T.reshape(m ** dim, dim))

这是另一个仅使用Pytorch的功能:

import torch

def mat_gen(dim, m): 
    return torch.stack(torch.meshgrid(*[torch.arange(1, m + 1, 1) for i in range(dim)])).T.reshape(m ** dim, dim)