numpy旋转矩阵乘法

时间:2019-01-02 14:11:34

标签: python numpy matrix rotational-matrices

我想使用numpy计算并旋转一系列旋转矩阵。我已经编写了这段代码来完成我的工作,

def npmat(angle_list):
    aa = np.full((nn, n, n),np.eye(n))
    c=0
    for j in range(1,n):
        for i in range(j):
            th = angle_list[c]
            aa[c,i,i]=aa[c,j,j] = np.cos(th)
            aa[c,i,j]= np.sin(th)
            aa[c,j,i]= -np.sin(th)
            c+=1
    return np.linalg.multi_dot(aa)

n,nn=3,3
#nn=n*(n-1)/2
angle_list= array([1.06426904, 0.27106789, 0.56149785])

npmat(angle_list)=
array([[ 0.46742875,  0.6710055 ,  0.57555363],
       [-0.84250501,  0.53532228,  0.06012796],
       [-0.26776049, -0.51301235,  0.81555052]])

但是我必须将此功能应用超过10K次,这非常慢,感觉好像没有充分利用numpy的潜力。是否有更有效的方法在numpy中执行此操作?

2 个答案:

答案 0 :(得分:0)

编辑:由于似乎您正在寻找这些矩阵的乘积,因此可以在不构造它们的情况下应用这些矩阵。仅计算余弦和正弦而不先进行矢量化也可能很有意义。

n=3
nn= n*(n-1)//2

theta_list = np.array([1.06426904, 0.27106789, 0.56149785])

sin_list = np.sin(theta_list)
cos_list = np.cos(theta_list)
A = np.eye(n)
c=0
for i in range(1,n):
    for j in range(i):
        ri = np.copy(A[i])
        rj = np.copy(A[j])

        A[i] = cos_list[c]*ri + sin_list[c]*rj
        A[j] = -sin_list[c]*ri + cos_list[c]*rj
        c+=1

print(A.T) // transpose at end because its faster to update A[i] than A[:,i]

如果要显式计算每个矩阵,则此处为某些原始代码的向量化版本。

n=4
nn= n*(n-1)//2

theta_list = np.random.rand(nn)*2*np.pi

sin_list = np.sin(theta_list)
cos_list = np.cos(theta_list)

aa = np.full((nn, n, n),np.eye(n))
ii,jj = np.tril_indices(n,k=-1)
cc = np.arange(nn)

aa[cc,ii,ii] = cos_list[cc]
aa[cc,jj,jj] = cos_list[cc]
aa[cc,ii,jj] = -sin_list[cc]
aa[cc,jj,ii] = sin_list[cc]

答案 1 :(得分:0)

向量化程度更高的解决方案:

0.5

似乎相当快:

def npmats(angle):
    a,b = angle.shape
    aa = np.full((a,b, n,n),np.eye(n))
    for j in range(1,n):
        for i in range(j):
            aa[:,:,i,i]=aa[:,:,j,j] = np.cos(angle)
            sinangle=np.sin(angle)
            aa[:,:,i,j]= sinangle
            aa[:,:,j,i]= -sinangle
    bb=np.empty((a,n,n))
    for i in range(a):
        bb[i]=np.linalg.multi_dot(aa[i])
    return bb