具有不同维度索引数组的索引pytorch张量

时间:2019-01-05 16:59:49

标签: python numpy indexing pytorch

我具有以下功能,该功能可以使用numpy.array进行操作,但是由于索引错误而在喂入torch.Tensor时会中断。

import torch
import numpy as np


def combination_matrix(arr):
    idxs = np.arange(len(arr))
    idx = np.ix_(idxs, idxs)
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        output = torch.zeros(len(arr), len(arr), 2, *arr.shape[1:], dtype=arr.dtype)
        num_dims = len(output.shape)
        print(arr[mesh].shape)  # <-- This is wrong/different to numpy!
        output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
        return output

    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()

问题是arr[mesh]的尺寸不同,具体取决于numpy和割炬。显然,pytorch不支持使用维数与被索引数组不同的索引数组进行索引。理想情况下,以下方法应该起作用:

features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)
features = torch.from_numpy(features)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())

但是尺寸不同:

(2, 3, 3, 3)
torch.Size([3, 3])

(从逻辑上)导致错误的原因:

Traceback (most recent call last):
  File "/home/XXX/util.py", line 226, in <module>
    torch_combs = combination_matrix(features)
  File "/home/XXX/util.py", line 218, in combination_matrix
    return torch_combination_matrix()
  File "/home/XXX/util.py", line 212, in torch_combination_matrix
    output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
RuntimeError: number of dims don't match in permute

我如何将火炬行为与numpy相匹配? 我已经在火炬论坛(例如this one with only one dimension)上阅读了各种问题,但可以在此处找到如何应用的方法。同样,index_select仅适用于一维,但我需要它至少适用于二维。

1 个答案:

答案 0 :(得分:1)

这实际上很容易做到。您只需要展平索引,然后重塑形状和排列尺寸。 这是完整的工作版本:

import torch
import numpy as np


def combination_matrix(arr):
    idxs = np.arange(len(arr))
    idx = np.ix_(idxs, idxs)
    mesh = np.stack(np.meshgrid(idxs, idxs))

    def np_combination_matrix():
        output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
        num_dims = len(output.shape)
        output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
        return output

    def torch_combination_matrix():
        output_shape = (2, len(arr), len(arr), *arr.shape[1:])  # Note that this is different to numpy!
        return arr[mesh.flatten()].reshape(output_shape).permute(2, 1, 0, *range(3, len(output_shape)))

    if isinstance(arr, np.ndarray):
        return np_combination_matrix()
    elif isinstance(arr, torch.Tensor):
        return torch_combination_matrix()

我使用pytest在不同维度的随机数组上运行此函数,它似乎在所有情况下均有效:

import pytest

@pytest.mark.parametrize('random_dims', range(1, 5))
def test_combination_matrix(random_dims):
    dim_size = np.random.randint(1, 40, size=random_dims)
    elements = np.random.random(size=dim_size)
    np_combs = combination_matrix(elements)
    features = torch.from_numpy(elements)
    torch_combs = combination_matrix(features)

    assert np.array_equal(np_combs, torch_combs.numpy())

if __name__ == '__main__':
    pytest.main(['-x', __file__])