修改函数内的numpy数组?

时间:2013-03-04 15:20:47

标签: python arrays numpy pass-by-reference

我对下面的简单程序有疑问:

def my_function(my_array = np.zeros(0)):
    my_array = [1, 2, 3]

my_array = np.zeros(0)
my_function(my_array)
print my_array

它打印一个空数组,就像复制传递my_array而不是函数内部的引用一样。如何纠正?

3 个答案:

答案 0 :(得分:8)

传递引用模型更像是指针的传值。因此,在my_function中,您有一个指向原始my_array的指针的副本。如果您要使用该指针直接操作输入的数组,这将导致更改,但重新分配复制的指针不会影响原始数组。

举个例子:

def my_func(a):
    a[1] = 2.0

ar = np.zeros(4)
my_func(ar)
print ar

上述代码将更改ar的内部值。

答案 1 :(得分:4)

您可以像在列表中一样使用切片分配:

def func(my_array):
    my_array[:3] = [1,2,3]

请注意,这仍然要求my_array中至少包含3个元素...示例用法:

>>> def func(my_array):
...     my_array[:3] = [1,2,3]
... 
>>> a = np.zeros(4)
>>> a
array([ 0.,  0.,  0.,  0.])
>>> func(a)
>>> a
array([ 1.,  2.,  3.,  0.])

你缺少的是python如何处理引用。输入my_function时,您将引用绑定到名称my_array的原始ndarray对象。但是,只要您分配该名称的新内容,就会丢失原始引用,并将其替换为对新对象的引用(在本例中为列表)。

请注意,拥有一个可变对象的默认参数通常可以lead to surprises

答案 2 :(得分:4)

np.zeros(0)为您提供一个空的numpy数组。你的函数内部的引用现在指向一个新的Python列表,但你实际上没有修改过你的空numpy数组,所以你仍然在打印出来。

建议阅读this answer以澄清一些概念。