我有一个numpy数组I
,用于存储大小为N
(像素数)的P
张图片。每张图片的大小为P = q*q
。
N = 1000 # number of images
q = 10 # length and width of image
P = q*q # pixels of image
I = np.ones((N,P)) # array with N images of size P
现在,我想删除所选索引ps
周围大小为IDX
的补丁(将所有值设置为零)。
ps = 2 # patch size (ps x ps)
IDX = np.random.randint(0,P,(N,1))
我的方法是使用reshape(q,q)
重塑每个图像并删除IDX
周围的像素。我有问题,我不知道如何计算给定IDX
的图像内部的位置。另外,我必须检查索引是否不在图像之外。
如何解决这个问题,有没有办法对这个程序进行矢量化?
编辑:
在@Brenlla的帮助下,我执行了以下操作来删除补丁。我的方法的问题是,它需要三个for循环,我必须重塑每个图像两次。有没有办法提高性能?这部分显着减慢了我的代码。
import numpy as np
import matplotlib.pyplot as plt
def myplot(I):
imgs = 10
for i in range(imgs**2):
plt.subplot(imgs,imgs,(i+1))
plt.imshow(I[i].reshape(q,q), interpolation="none")
plt.axis("off")
plt.show()
N = 10000
q = 28
P = q*q
I = np.random.rand(N,P)
ps = 3
IDX = np.random.randint(0,P,(N,1))
for i in range(N):
img = I[i].reshape(q,q)
y0, x0 = np.unravel_index(IDX[i,0],(q,q))
for x in range(ps):
for y in range(ps):
if (x0+x < q) and (y0+y < q):
img[x0+x,y0+y] = 2.0
I[i] = img.reshape(1,q*q)
myplot(I)
答案 0 :(得分:1)
是的,可以这样做,但它涉及大量使用np.broadcasting。
生成数据以及OP_SELL
I
现在运行循环解决方案。我切换了import time
N = 10000
q = 28
P = q*q
ps = 3
I = np.random.rand(N,P)
IDX = np.random.randint(0,P,(N,1))
I_copy = I.copy()
和x0
:
y0
约在我的机器上276毫秒。现在广播:
t0=time.clock()
for i in range(N):
img = I[i].reshape(q,q)
x0, y0 = np.unravel_index(IDX[i,0],(q,q))
for x in range(ps):
for y in range(ps):
if (x0+x < q) and (y0+y < q):
img[x0+x,y0+y] = 2.0
I[i] = img.reshape(1,q*q)
print('With loop: {:.2f} ms'.format(time.clock()*1e3-t0*1e3))
大约快80倍