在z方向上的Numpy广播

时间:2017-12-31 11:53:26

标签: python numpy broadcasting

在机器学习环境中,我需要进行每元素乘法运算。为了有效地做到这一点,我需要以特定方式广播3D张量的元素,使得每个2x2矩阵重复n次,如以下示例所示,其中n = 2:

import numpy as np

a = np.linspace(1,12,12)
a = a.reshape(3,2,2)

# what to put here?
<some statements>

print a

# result:
[[[  1.   2.]
  [  3.   4.]]

 [[  1.   2.]
  [  3.   4.]]

 [[  5.   6.]
  [  7.   8.]]

 [[  5.   6.]
  [  7.   8.]]

 [[  9.  10.]
  [ 11.  12.]]

 [[  9.  10.]
  [ 11.  12.]]]

什么声明会起作用?

谢谢!

1 个答案:

答案 0 :(得分:3)

在您a作为3D数组后,np.repeatnp.broadcast_to向第一个轴复制的那个 -

N = 2 # replication number
out = np.repeat(a,N,axis=0)

或者,对于4D只读输出,我们可以使用RecyclerView创建一个非常有效的视图,因为我们不会占用任何额外的内存,例如所以 -

m,n,r = a.shape
out = np.broadcast_to(a[:,None],(m,N,n,r))

# Confirm it's a view
In [32]: np.shares_memory(a, out)
Out[32]: True