Broadcasting a view irregularly in numpy

Time: 2018-06-01 07:47:16

Tags: python numpy numpy-broadcasting

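Suppose I want a numpy array of size (n, m), where n is very large but there is a lot of duplication, i.e. rows 0:n1 are identical, rows n1:n2 are identical, and so on (with n2 % n1 != 0, i.e. the intervals are not regular). Is there a way to store only one set of values for each group of duplicates while still having a view of the entire array?

Example: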

import numpy as np

unique_values = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])  # these are the values I want to store in memory
index_mapping = np.array([0, 0, 1, 1, 1, 2, 2])  # maps each row of the array below to a row of the array above

unique_values_view = np.array([[1, 1, 1], [1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [3, 3, 3], [3, 3, 3]])  # this is what I want the view to look like, for broadcasting reasons

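I plan to multiply the array (the view) by some other array of size (1, m) and then take the dot product of that product with an array of size (m, 1):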

other_array1 = np.arange(unique_values.shape[1]).reshape(1,-1) # (1,m)
other_array2 = 2*np.ones((unique_values.shape[1],1)) # (m,1)
output = np.dot(unique_values_view * other_array1, other_array2).squeeze()

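The output is a 1-D array of length n.

2 answers:

Answer 0: (score: 6)

Based on your example, you can simply factor the index mapping through the computation to the very end: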

output = np.dot(unique_values * other_array1, other_array2).squeeze()[index_mapping]
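As a quick sanity check of the idea of indexing last, here is a minimal self-contained sketch using the arrays from the question. The names `expanded`, `full` and `indexed_last` are only illustrative and not part of the original post.

```python
import numpy as np

unique_values = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
index_mapping = np.array([0, 0, 1, 1, 1, 2, 2])
other_array1 = np.arange(unique_values.shape[1]).reshape(1, -1)  # (1, m)
other_array2 = 2 * np.ones((unique_values.shape[1], 1))          # (m, 1)

# Reference: materialize the full (n, m) array first.
# Fancy indexing returns a copy here, not a view.
expanded = unique_values[index_mapping]
full = np.dot(expanded * other_array1, other_array2).squeeze()

# Same computation on the unique rows only, expanded to length n at the end.
indexed_last = np.dot(unique_values * other_array1, other_array2).squeeze()[index_mapping]

print(np.array_equal(full, indexed_last))
```

Only the small (number of unique rows, m) array takes part in the multiplication; the length-n result is produced by a single 1-D indexing step.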

Answer 1: (score: 1)

Your expression admits two significant optimizations:

  • do the indexing last
  • multiply other_array1 with other_array2 first, and only then with unique_values

Let's apply these optimizations:

>>> output_pp = (unique_values @ (other_array1.ravel() * other_array2.ravel()))[index_mapping]
# check for correctness
>>> (output == output_pp).all()
True
# and compare it to @Yakym Pirozhenko's approach
>>> from timeit import timeit
>>> print("yp:", timeit("np.dot(unique_values * other_array1, other_array2).squeeze()[index_mapping]", globals=globals()))
yp: 3.9105667411349714
>>> print("pp:", timeit("(unique_values @ (other_array1.ravel() * other_array2.ravel()))[index_mapping]", globals=globals()))
pp: 2.2684884609188884

These optimizations are easy to spot if we observe two things:

(1) If A is an mxn-matrix and b is an n-vector, then

A * b == A @ diag(b)
A.T * b[:, None] == diag(b) @ A.T

(2) If A is an mxn-matrix and I is a k-vector of integers from range(m), then

A[I] == onehot(I) @ A

onehot can be defined as

def onehot(I, m, dtype=int):
    out = np.zeros((I.size, m), dtype=dtype)
    out[np.arange(I.size), I] = 1
    return out
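Both identities can be checked numerically. The following is a small sketch in which the matrix `A`, the vector `b` and the index vector `I` are made-up test values, `diag` is taken to be `np.diag`, and `onehot` is called with the number of rows `m` passed explicitly:

```python
import numpy as np

def onehot(I, m, dtype=int):
    # Row i of the result has a single 1 in column I[i].
    out = np.zeros((I.size, m), dtype=dtype)
    out[np.arange(I.size), I] = 1
    return out

rng = np.random.default_rng(0)
A = rng.integers(0, 10, size=(3, 4))  # m x n matrix with m=3, n=4
b = rng.integers(0, 10, size=4)       # n-vector
I = np.array([0, 0, 1, 1, 1, 2, 2])   # k-vector of integers from range(m)

# (1) Scaling the columns of A equals right-multiplying by a diagonal matrix.
check1a = np.array_equal(A * b, A @ np.diag(b))
check1b = np.array_equal(A.T * b[:, None], np.diag(b) @ A.T)

# (2) Selecting rows by index equals left-multiplying by a one-hot matrix.
check2 = np.array_equal(A[I], onehot(I, A.shape[0]) @ A)

print(check1a, check1b, check2)
```

Integer test data is used so that the comparisons are exact; with floating-point inputs `np.allclose` would be the safer comparison.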

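Using these facts and abbreviating unique_values, index_mapping, other_array1 and other_array2 as uv, im, oa1 and oa2, we can write

uv[im] * oa1 @ oa2 == onehot(im) @ uv @ diag(oa1) @ oa2

The above optimizations are now simply a matter of choosing the best order for these matrix multiplications, which is

onehot(im) @ (uv @ (diag(oa1) @ oa2))

Using (1) and (2) backwards on this, we obtain the optimized expression from the beginning of this post.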
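To make the choice of multiplication order concrete, here is a minimal sketch with the arrays from the question. It assumes oa1 and oa2 are flattened to 1-D vectors; the selection matrix `P` and the result names are illustrative only.

```python
import numpy as np

uv = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]])  # unique_values
im = np.array([0, 0, 1, 1, 1, 2, 2])              # index_mapping
oa1 = np.arange(3)                                # other_array1, flattened
oa2 = 2 * np.ones(3)                              # other_array2, flattened

# One-hot selection matrix: row i has a 1 in column im[i].
P = np.zeros((im.size, uv.shape[0]))
P[np.arange(im.size), im] = 1

# Left to right: the first product is already a full (n, m) matrix.
left_to_right = ((P @ uv) @ np.diag(oa1)) @ oa2

# Right to left: every intermediate result is a vector.
right_to_left = P @ (uv @ (np.diag(oa1) @ oa2))

# (1) and (2) applied backwards: no diag or one-hot matrix is ever built.
optimized = (uv @ (oa1 * oa2))[im]

print(np.allclose(left_to_right, right_to_left), np.allclose(right_to_left, optimized))
```

All three expressions give the same length-n vector; they differ only in the size of the intermediate arrays, which is why the right-to-left grouping is the cheaper one when n is large.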