我想知道如何以更有效的方式对任意数组执行此操作,该代码是用PyTorch编写的,但仅适用于1-d张量。 谢谢!
test=[]
data=np.random.uniform(0,1,[20,])
x=torch.from_numpy(data).float()
x,_=torch.sort(x)
v=torch.rand(5).float()
v,_=torch.sort(v)
for i in range(len(x)):
if x[i] < v[0]:
test.append(v[0])
elif x[i] < v[1]:
test.append(v[1])
elif x[i] < v[2]:
test.append(v[2])
elif x[i] < v[3]:
test.append(v[3])
else:
test.append(v[4])
test
答案 0 :(得分:2)
您可以使用内置功能next:
for i in x:
test.append(next((e for e in v[:4] if i < e), v[4]))
您还可以使用列表推导代替for
循环:
s = v[:4]
d = v[4]
test = [next((e for e in s if i < e), d)) for i in x]
如果test
变量已经具有某些元素,则可以使用i n-place assignment +=
运算符:
test += [next((e for e in s if i < e), d) for i in x]
答案 1 :(得分:0)
对功能的适度更改:
def foo0(x,v):
test = []
for i in x:
if i<v[0]:
test.append(v[0])
elif i<v[1]:
test.append(v[1])
elif i<v[2]:
test.append(v[2])
elif i<v[3]:
test.append(v[3])
else:
test.append(v[4])
return test
使用整数数组(也进行排序)对其进行测试,以便于比较:
In [152]: x = np.arange(20); v = np.arange(0,20,4)
In [153]: x,v
Out[153]:
(array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19]), array([ 0, 4, 8, 12, 16]))
In [154]: foo0(x,v)
Out[154]: [4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 16]
In [155]: timeit foo0(x,v)
21.2 µs ± 471 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
您的版本在range
上进行迭代,并在29处多次使用x[i]
。
numpy
版本是:
def fooN(x,v):
v1 = v.copy()
v1[-1] = np.max(x)+1
temp = x[:,None]<v1
idx = np.argmax(temp, axis=1)
return v[idx].tolist()
In [158]: fooN(x,v)
Out[158]: [4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 16]
In [159]: timeit fooN(x,v)
27.7 µs ± 842 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
对于这个小样本,它比较慢。但是,使用更大的样本,速度更快。但是,它仍然需要使用较大的阵列进行一些调试。因此,我将其作为概念验证而非最终答案。
将列表传递给函数可以使其更快
In [185]: %%timeit x1=x.tolist(); v1=v.tolist()
...: foo0(x1,v1)
5.12 µs ± 8.78 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
这是您的函数的版本,它更紧凑,更通用-但速度较慢
def foo01(x,v):
# getting else v[4] part requires some extra effort
test = []
for i in x:
done = False
for j in v[:-1]:
if i<j:
test.append(j)
done = True
break
if not done:
test.append(v[-1])
return test
===
对于较大的x
(v
仍然是5个元素):
In [215]: x = np.arange(2000); v = np.arange(0,2000,400)
In [216]: v
Out[216]: array([ 0, 400, 800, 1200, 1600])
In [217]: np.array(foo0(x,v))
Out[217]: array([ 400, 400, 400, ..., 1600, 1600, 1600])
In [218]: np.array(fooN(x,v))
Out[218]: array([ 400, 400, 400, ..., 1600, 1600, 1600])
In [219]: timeit foo0(x,v)
1.95 ms ± 24.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [220]: timeit foo01(x,v)
3.14 ms ± 36.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [221]: timeit fooN(x,v)
147 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)