用火炬随机选择?

时间:2019-12-23 22:14:52

标签: python numpy pytorch

我有一个int张量,并且想从中随机选择。我正在寻找np.random.choice()的等价物。

import torch

choices = torch.tensor([2, 4, 6, 7])

5 个答案:

答案 0 :(得分:2)

在我的情况下:values.shape =(386363948,2),k = 190973,下面的代码相当快地工作。大约需要0.2秒。

N, D = 386363948, 2
k = 190973
values = torch.randn(N, D)

# The following code cost 0.2 second
indice = random.sample(range(N), k)
indice = torch.tensor(indice)
sampled_values = values[indice]

但是,使用torch.randperm将花费超过20秒的时间。

# Cost more than 20 second
sampled_values = values[torch.randperm(N)[:k]]

答案 1 :(得分:0)

torch没有np.random.choice()的等效实现。最好的办法是从选择中选择一个随机索引。

choices[torch.randint(choices.shape[0], (1,))]

这将在0和张量中的元素数量之间生成一个randint

for i in range(5):
    print(choices[torch.randint(choices.shape[0], (1,))])
tensor([2])
tensor([6])
tensor([2])
tensor([6])
tensor([7])

如果要设置replacement = False,请删除使用掩码选择的值:

for i in range(10):
    value = choices[torch.randint(choices.shape[0], (1,))]
    choices = choices[choices!=value]
    print(value)
tensor([2])
tensor([4])
tensor([6])
tensor([7])

答案 2 :(得分:0)

正如其他提到的那样,火炬没有选择 您可以改用randint或置换

import torch
n = 4
choices = torch.rand(4, 3)
choices_flat = choices.view(-1)
index = torch.randint(choices_flat.numel(), (n,))
# or if replace = False
index = torch.randperm(choices_flat.numel())[:n]
select = choices_flat[index]

答案 3 :(得分:0)

torch.multinomial 提供与 numpy 的 random.choice 等效的行为(包括带/不带替换的采样):

# Uniform weights for random draw
unif = torch.ones(pictures.shape[0])

idx = unif.multinomial(10, replacement=True)
samples = pictures[idx]
samples.shape
>>> torch.Size([10, 28, 28, 3])

答案 4 :(得分:0)

试试这个:

input_tensor = torch.randn(5, 8)
print(input_tensor)
indices = torch.LongTensor(np.random.choice(5,2, replace=False)) 
output_tensor = torch.index_select(input_tensor, 0, indices)
print(output_tensor)