在numpy中随机索引多维数组

时间:2019-11-19 04:31:24

标签: python numpy multidimensional-array numpy-ndarray

如何更改此代码,避免使用python for-loop,但使用numpy函数

h, w = 2, 2
im = np.random.randint(255, size=(h, w, 3))
index = np.random.randint(3, size=(h, w))
number = np.random.randint(255, size=(h, w))
for i in range(h):
    for j in range(w):
        im[i, j, index[i, j]] += number[i, j]

2 个答案:

答案 0 :(得分:0)

下面的代码与原始代码相同,但是避免使用for循环。但是,您不清楚为什么要这样做,因为我认为该解决方案肯定比原始解决方案差。

from timeit import timeit
import numpy as np

h, w = 20, 20
im = np.random.randint(255, size=(h, w, 3))


def increase_random(x):
    result = np.copy(x)
    result[np.random.randint(3)] += np.random.randint(255)
    return result


def loops():
    index = np.random.randint(3, size=(h, w))
    number = np.random.randint(255, size=(h, w))
    for i in range(h):
        for j in range(w):
            im[i, j, index[i, j]] += number[i, j]


def vectorized():
    irv(im)


irv = np.vectorize(increase_random, signature='(n)->(n)')

print(timeit(vectorized, number=10))
print(timeit(loops, number=10))

我添加了一些时间测量结果,以表明在这种情况下,矢量化无法提高性能。在我的机器上,loops代码的速度大约快25倍。

但是,如果您正在执行的操作更简单或更复杂但更易于优化,则可能会受益于矢量化。碰巧的是,您的示例不太可能从中受益,而循环非常小且有效。

答案 1 :(得分:0)

使用与您一样小的示例,您将很难加快循环代码的速度。 numpy支付仅从某个问题规模开始产生的开销:

enter image description here

该图显示了在像素总数OP中原始循环代码(pp)与矢量化代码(w x h)的执行时间。

它是使用以下命令生成的:

from simple_benchmark import BenchmarkBuilder, MultiArgument
import numpy as np
from scipy.misc import face

B = BenchmarkBuilder()

@B.add_function()
def OP(im,index,number):
    im = im.copy()
    h,w,_ = im.shape
    for i in range(h):
        for j in range(w):
            im[i, j, index[i, j]] += number[i, j]
    return im

@B.add_function()
def pp(im,index,number):
    im = im.copy()
    h,w,_ = im.shape
    h,w = np.ogrid[:h,:w]
    im[h,w,index] += number
    return im

@B.add_arguments('#pixels')
def argument_provider():
    im = face()
    h,w,_ = im.shape
    mh,mw = h//2,w//2
    for exp in range(-8,1):
        fr = 2.**exp
        dh,dw = int(fr*mh),int(fr*mw)
        index = np.random.randint(3, size=(2*dh, 2*dw))
        number = np.random.randint(255, size=(2*dh, 2*dw),dtype=im.dtype)
        yield 4*dh*dw,MultiArgument([im[mh-dh:mh+dh,mw-dw:mw+dw],index,number])

r = B.run()
r.plot()

import pylab
pylab.savefig('randomchannel.png')