假设我想要一个大小为(n,m)
的numpy数组,其中n
非常大,但有很多重复,即。 0:n1
相同,n1:n2
相同等(n2%n1!=0
,即不是常规间隔)。有没有办法只为每个重复项存储一组值,同时拥有整个数组的视图?
示例:
unique_values = np.array([[1,1,1], [2,2,2] ,[3,3,3]]) #these are the values i want to store in memory
index_mapping = np.array([0,0,1,1,1,2,2]) # a mapping between index of array above, with array below
unique_values_view = np.array([[1,1,1],[1,1,1],[2,2,2],[2,2,2],[2,2,2], [3,3,3],[3,3,3]]) #this is how I want the view to look like for broadcasting reasons
我计划将数组(视图)乘以一些大小为(1,m)
的其他数组,并采用此产品的点积:
other_array1 = np.arange(unique_values.shape[1]).reshape(1,-1) # (1,m)
other_array2 = 2*np.ones((unique_values.shape[1],1)) # (m,1)
output = np.dot(unique_values_view * other_array1, other_array2).squeeze()
输出是长度为n
的一维数组。
答案 0 :(得分:6)
根据您的示例,您可以简单地将通过计算的索引映射计算到最后:
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
<include resource="org/springframework/boot/logging/logback/base.xml"/>
<logger name="org.springframework.web" level="DEBUG"/>
</configuration>
答案 1 :(得分:1)
你的表达允许两个重要的优化:
other_array1
与other_array2
相乘,然后再将其与unique_values
相乘
>>> output_pp = (unique_values @ (other_array1.ravel() * other_array2.ravel()))[index_mapping]
# check for correctness
>>> (output == output_pp).all()
True
# and compare it to @Yakym Pirozhenko's approach
>>> from timeit import timeit
>>> print("yp:", timeit("np.dot(unique_values * other_array1, other_array2).squeeze()[index_mapping]", globals=globals()))
yp: 3.9105667411349714
>>> print("pp:", timeit("(unique_values @ (other_array1.ravel() * other_array2.ravel()))[index_mapping]", globals=globals()))
pp: 2.2684884609188884
让我们应用这些优化:
A
如果我们观察到两件事情,很容易发现这些优化:
(1)如果mxn
是b
- 矩阵且n
是A * b == A @ diag(b)
A.T * b[:, None] == diag(b) @ A.T
- 矢量那么
A
(2)如果mxn
是I
- 矩阵且k
是range(m)
- 来自的整数向量
A[I] == onehot(I) @ A
然后
onehot
def onehot(I, m, dtype=int):
out = np.zeros((I.size, m), dtype=dtype)
out[np.arange(I.size), I] = 1
return out
可以定义为
uv
使用这些事实并缩写im
,oa1
,oa2
和uv[im] * oa1 @ oa2 == onehot(im) @ uv @ diag(oa1) @ oa2
我们可以写
onehot(im) @ (uv @ (diag(oa1) @ oa2))
上述优化现在只需选择这些矩阵乘法的最佳阶数
.exclude(..)
使用(1)和(2)向后,我们从本文开头获得优化表达式。