迭代一个numpy数组

时间:2011-08-06 14:27:04

标签: python numpy

是否有一个不那么冗长的替代方案:

for x in xrange(array.shape[0]):
    for y in xrange(array.shape[1]):
        do_stuff(x, y)

我想出了这个:

for x, y in itertools.product(map(xrange, array.shape)):
    do_stuff(x, y)

这可以节省一个缩进,但仍然非常难看。

我希望看起来像这个伪代码的东西:

for x, y in array.indices:
    do_stuff(x, y)

有这样的事吗?

4 个答案:

答案 0 :(得分:170)

我认为你正在寻找ndenumerate

>>> a =numpy.array([[1,2],[3,4],[5,6]])
>>> for (x,y), value in numpy.ndenumerate(a):
...  print x,y
... 
0 0
0 1
1 0
1 1
2 0
2 1

关于表现。它比列表理解慢一点。

X = np.zeros((100, 100, 100))

%timeit list([((i,j,k), X[i,j,k]) for i in range(X.shape[0]) for j in range(X.shape[1]) for k in range(X.shape[2])])
1 loop, best of 3: 376 ms per loop

%timeit list(np.ndenumerate(X))
1 loop, best of 3: 570 ms per loop

如果您担心性能,可以通过查看ndenumerate的实现来进一步优化,它可以完成两件事,转换为数组并循环。如果您知道有数组,则可以调用flat iterator的.coords属性。

a = X.flat
%timeit list([(a.coords, x) for x in a.flat])
1 loop, best of 3: 305 ms per loop

答案 1 :(得分:41)

如果您只需要索引,可以尝试numpy.ndindex

>>> a = numpy.arange(9).reshape(3, 3)
>>> [(x, y) for x, y in numpy.ndindex(a.shape)]
[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]

答案 2 :(得分:13)

请参阅nditer

import numpy as np
Y = np.array([3,4,5,6])
for y in np.nditer(Y, op_flags=['readwrite']):
    y += 3

Y == np.array([6, 7, 8, 9])
  

y = 3不起作用,请改用y *= 0y += 3

答案 3 :(得分:0)

我发现这里没有使用 numpy.nditer() 的好描述。所以,我要带一个。 根据 NumPy v1.21 dev0 manual 的说法,NumPy 1.6 中引入的迭代器对象 nditer 提供了许多灵活的方法来以系统的方式访问一个或多个数组的所有元素。

我必须计算 mean_squared_error 并且我已经计算了 y_predicted 并且我有来自波士顿数据集的 y_actual,可通过 sklearn 获得。

def cal_mse(y_actual, y_predicted):
    """ this function will return mean squared error
       args:
           y_actual (ndarray): np array containing target variable
           y_predicted (ndarray): np array containing predictions from DecisionTreeRegressor
       returns:
           mse (integer)
    """
    sq_error = 0
    for i in np.nditer(np.arange(y_pred.shape[0])):
        sq_error += (y_actual[i] - y_predicted[i])**2
    mse = 1/y_actual.shape[0] * sq_error
    
    return mse

希望这有帮助:)。进一步说明visit