我具有以下功能,该功能可以使用numpy.array
进行操作,但是由于索引错误而在喂入torch.Tensor
时会中断。
import torch
import numpy as np
def combination_matrix(arr):
idxs = np.arange(len(arr))
idx = np.ix_(idxs, idxs)
mesh = np.stack(np.meshgrid(idxs, idxs))
def np_combination_matrix():
output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
num_dims = len(output.shape)
output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
return output
def torch_combination_matrix():
output = torch.zeros(len(arr), len(arr), 2, *arr.shape[1:], dtype=arr.dtype)
num_dims = len(output.shape)
print(arr[mesh].shape) # <-- This is wrong/different to numpy!
output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
return output
if isinstance(arr, np.ndarray):
return np_combination_matrix()
elif isinstance(arr, torch.Tensor):
return torch_combination_matrix()
问题是arr[mesh]
的尺寸不同,具体取决于numpy和割炬。显然,pytorch不支持使用维数与被索引数组不同的索引数组进行索引。理想情况下,以下方法应该起作用:
features = np.arange(9).reshape(3, 3)
np_combs = combination_matrix(features)
features = torch.from_numpy(features)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())
但是尺寸不同:
(2, 3, 3, 3)
torch.Size([3, 3])
(从逻辑上)导致错误的原因:
Traceback (most recent call last):
File "/home/XXX/util.py", line 226, in <module>
torch_combs = combination_matrix(features)
File "/home/XXX/util.py", line 218, in combination_matrix
return torch_combination_matrix()
File "/home/XXX/util.py", line 212, in torch_combination_matrix
output[idx] = arr[mesh].permute(2, 1, 0, *np.arange(3, num_dims))
RuntimeError: number of dims don't match in permute
我如何将火炬行为与numpy相匹配? 我已经在火炬论坛(例如this one with only one dimension)上阅读了各种问题,但可以在此处找到如何应用的方法。同样,index_select仅适用于一维,但我需要它至少适用于二维。
答案 0 :(得分:1)
这实际上很容易做到。您只需要展平索引,然后重塑形状和排列尺寸。 这是完整的工作版本:
import torch
import numpy as np
def combination_matrix(arr):
idxs = np.arange(len(arr))
idx = np.ix_(idxs, idxs)
mesh = np.stack(np.meshgrid(idxs, idxs))
def np_combination_matrix():
output = np.zeros((len(arr), len(arr), 2, *arr.shape[1:]), dtype=arr.dtype)
num_dims = len(output.shape)
output[idx] = arr[mesh].transpose((2, 1, 0, *np.arange(3, num_dims)))
return output
def torch_combination_matrix():
output_shape = (2, len(arr), len(arr), *arr.shape[1:]) # Note that this is different to numpy!
return arr[mesh.flatten()].reshape(output_shape).permute(2, 1, 0, *range(3, len(output_shape)))
if isinstance(arr, np.ndarray):
return np_combination_matrix()
elif isinstance(arr, torch.Tensor):
return torch_combination_matrix()
我使用pytest在不同维度的随机数组上运行此函数,它似乎在所有情况下均有效:
import pytest
@pytest.mark.parametrize('random_dims', range(1, 5))
def test_combination_matrix(random_dims):
dim_size = np.random.randint(1, 40, size=random_dims)
elements = np.random.random(size=dim_size)
np_combs = combination_matrix(elements)
features = torch.from_numpy(elements)
torch_combs = combination_matrix(features)
assert np.array_equal(np_combs, torch_combs.numpy())
if __name__ == '__main__':
pytest.main(['-x', __file__])