I want to compute something with numpy and numba. Here is my code:
import numpy as np
from numba import jit,double,int64
@jit(locals=dict(i=int64,j=int64,k=int64,l=int64,suma=double))
def omega_comp_arrays(omega,p_kl,eta,theta,K,L,links_by_ratings_array):
    #new_omega = np.zeros(omega.shape)
    for rating,links in enumerate(links_by_ratings_array):
        for i,j in links:
            suma = 0
            for k in range(K):
                for l in range(L):
                    omega[i,j,k,l] = p_kl[k,l,rating]*theta[i,k]*eta[j,l]
                    suma += omega[i,j,k,l]
            omega[i,j,:,:] /= suma
    return omega
N_nodes=2
N_veins=[1,1]
N_items=2
N_ratings=1
K=2
L=2
## Define the matrices
theta = np.random.rand(N_nodes,K)
eta = np.random.rand(N_items,L)
p_kl = np.random.rand(K,L,N_ratings)
suma = np.sum(theta,axis =1)
theta /=suma[:,np.newaxis]
suma = np.sum(eta,axis=1)
eta /= suma[:,np.newaxis]
suma = np.sum(p_kl,axis =2)
p_kl /=suma[:,:,np.newaxis]
links_by_ratings_array = [np.array([0,0])]
omega = np.ones((N_nodes,N_items,K,L))
omega = omega_comp_arrays(omega,p_kl,eta,theta,K,L,links_by_ratings_array)
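For reference, reading off the loop body, the function sets, for each rating r and each link (i, j) in the r-th list:

$$\omega_{ijkl} = \frac{p_{klr}\,\theta_{ik}\,\eta_{jl}}{\sum_{k'}\sum_{l'} p_{k'l'r}\,\theta_{ik'}\,\eta_{jl'}}$$

A minimal NumPy-broadcasting sketch of the same update, assuming each links array has shape (n_links, 2) with one (i, j) pair per row. The name omega_comp_vectorized is hypothetical and not part of the code above:

```python
import numpy as np


def omega_comp_vectorized(omega, p_kl, eta, theta, links_by_ratings_array):
    """Broadcasting version of the loop body.

    Assumption: every element of links_by_ratings_array is an integer
    array of shape (n_links, 2) whose rows are (i, j) pairs.
    """
    for rating, links in enumerate(links_by_ratings_array):
        i = links[:, 0]
        j = links[:, 1]
        # (n, K, 1) * (n, 1, L) * (1, K, L) -> (n, K, L)
        block = theta[i][:, :, None] * eta[j][:, None, :] * p_kl[None, :, :, rating]
        # normalise each link's K x L block so it sums to 1
        block /= block.sum(axis=(1, 2), keepdims=True)
        omega[i, j] = block
    return omega
```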
Running the code gives this error:
Traceback (most recent call last):
  File "test_omega.py", line 39, in <module>
    omega = omega_comp_arrays(omega,p_kl,eta,theta,K,L,links_by_ratings_array)
TypeError: 'numpy.int64' object is not iterable
Exception TypeError: "'NoneType' object is not callable" in <bound method ModuleRef.__del__ of <llvmlite.binding.module.ModuleRef object at 0x7f801123d3d0>> ignored
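This first error can be reproduced without numba at all, presumably because @jit without nopython=True fell back to object mode, where the loop runs with ordinary Python semantics. Iterating over a 1-D array such as np.array([0,0]) yields its scalar elements one at a time, so `for i,j in links` tries to unpack a single integer into two names. A minimal sketch; first_pair_unpack_error is a hypothetical helper:

```python
import numpy as np


def first_pair_unpack_error(links):
    """Return the TypeError raised by `for i, j in links`, or None if it works."""
    try:
        for i, j in links:  # each element must itself hold exactly two values
            pass
    except TypeError as exc:
        return exc
    return None


flat = np.array([0, 0])    # 1-D: iterating yields two integer scalars
rows = np.array([[0, 0]])  # 2-D, shape (1, 2): iterating yields one row [0, 0]

print(type(next(iter(flat))))         # a NumPy integer scalar, not a pair
print(first_pair_unpack_error(flat))  # the TypeError message (wording differs between Python 2 and 3)
print(first_pair_unpack_error(rows))  # None: the row unpacks into i, j
```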
But if I enable nopython mode, I get a different error:
Traceback (most recent call last):
  File "test_omega.py", line 39, in <module>
    omega = omega_comp_arrays(omega,p_kl,eta,theta,K,L,links_by_ratings_array)
  File "/usr/local/lib/python2.7/dist-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
numba.errors.TypingError: Caused By:
Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/numba/compiler.py", line 240, in run
    stage()
  File "/usr/local/lib/python2.7/dist-packages/numba/compiler.py", line 454, in stage_nopython_frontend
    self.locals)
  File "/usr/local/lib/python2.7/dist-packages/numba/compiler.py", line 881, in type_inference_stage
    infer.propagate()
  File "/usr/local/lib/python2.7/dist-packages/numba/typeinfer.py", line 846, in propagate
    raise errors[0]
TypingError: failed to unpack int64
File "test_omega.py", line 10
[1] During: typing of exhaust iter at test_omega.py (10)
Failed at nopython (nopython frontend)
failed to unpack int64
File "test_omega.py", line 10
[1] During: typing of exhaust iter at test_omega.py (10)
In other words, I can't loop over a list of arrays: the error occurs in the links loop (for i,j in links). Any suggestions?
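Both tracebacks seem to point to the same cause rather than to looping over a list of arrays as such: np.array([0,0]) is 1-D, so each element of links is a single int64, and numba's type inference reports it as "failed to unpack int64". A likely fix is to make each entry a 2-D array with one (i, j) pair per row, e.g. [np.array([[0, 0]])]. The sketch below assumes that shape, reads i and j by index instead of by tuple unpacking (unpacking a row may also work in nopython mode, but indexing avoids relying on it), and uses a try/except fallback so it runs even without numba; these choices are assumptions, not the question's original code. Newer numba releases may warn about passing a plain Python list (a "reflected list"); numba.typed.List is the suggested replacement there.

```python
import numpy as np

try:
    from numba import njit  # compile in nopython mode when numba is installed
except ImportError:  # assumption: plain-Python fallback so the sketch runs anywhere
    def njit(func):
        return func


@njit
def omega_comp_arrays(omega, p_kl, eta, theta, K, L, links_by_ratings_array):
    for rating, links in enumerate(links_by_ratings_array):
        # links is assumed to be a 2-D integer array of shape (n_links, 2),
        # one (i, j) pair per row
        for n in range(links.shape[0]):
            i = links[n, 0]
            j = links[n, 1]
            suma = 0.0
            for k in range(K):
                for l in range(L):
                    omega[i, j, k, l] = p_kl[k, l, rating] * theta[i, k] * eta[j, l]
                    suma += omega[i, j, k, l]
            omega[i, j, :, :] /= suma  # normalise over (k, l)
    return omega


# One rating, one link (0, 0): the extra brackets make a (1, 2) array
links_by_ratings_array = [np.array([[0, 0]])]
```

With this input shape, the question's call omega_comp_arrays(omega,p_kl,eta,theta,K,L,links_by_ratings_array) should iterate over (i, j) rows instead of scalars.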