Numpy调整大小而不修改原始对象

时间:2015-05-10 03:23:28

标签: python numpy

我想扩展两个表的维度,以便我可以使用numpy的广播乘法。我使用了以下代码:

def tableResize(table1,table2,var1,var2):
    n1=[1]*len(var1)
    n2=[1]*len(var2)
    table1.resize(list(table1.shape)+n2)
    table2.resize(n1+list(table2.shape))
    return table1,table2

假设table1为2 * 3,table2为3 * 4,扩展表为2 * 3 * 1 * 1和1 * 1 * 3 * 4。 虽然我注意到我可以写

table1[:,:,np.newaxis,np.newaxis]*table2[np.newaxis,np.newaxis,:,:]

这对table1和table2本身没有任何影响。但我不知道如何自动生成列表[:,:,np.newaxis,np.newaxis]

但是,resize方法没有返回值,它将修改对象本身。我不想使用deepcopy。有人有想法吗?非常感谢^ _ ^

2 个答案:

答案 0 :(得分:2)

  

但是,resize方法没有返回值,它将修改对象本身。

是的,但如果您查看the docs for the resize method,它会给您答案:resize function

NumPy中有很多这样的对,其中np.spam(a, eggs)制作了a的垃圾邮件副本,而a.spam(eggs)就地传播了a。如果您查看文档,它们将被链接在一起。

所以,我认为你在寻找的是:

t1 = np.resize(table1, list(table1.shape)+n2)
t2 = np.resize(table2, n1+list(table2.shape))
return t1, t2

答案 1 :(得分:1)

a1 = a.reshape(...)返回一个视图 - a1有一个新形状,但共享数据缓冲区。

a1 = table1.reshape(table1.shape+(1,)*table2.ndim)
 # (2, 3, 1, 1)

b1 = table2.reshape((1,)*table1.ndim+table2.shape)
 # (1, 1, 3, 4)

由于维度在开始时根据需要添加,因此table2不需要展开。

a1 + b1 == a1 + table2

请查看np.atleast_3dnp.broadcast_arrays,了解有关如何扩展数组维度的其他建议。

进一步观察resize,我会说,无论是哪种形式,它都是错误的功能,当你想要做的只是添加单身尺寸时。 reshape是正确的函数/方法。要么是np.newaxis

您可以通过连接切片和[:,:,np.newaxis,np.newaxis]

来构建None
s=[slice(None)]*2 + [None]*2
# [slice(None, None, None), slice(None, None, None), None, None]

table1[s].shape
# (2, 3, 1, 1)

np.resize代码:

File:        /usr/lib/python3/dist-packages/numpy/core/fromnumeric.py
def resize(a, new_shape):

    if isinstance(new_shape, (int, nt.integer)):
        new_shape = (new_shape,)
    a = ravel(a)
    Na = len(a)
    if not Na: return mu.zeros(new_shape, a.dtype.char)
    total_size = um.multiply.reduce(new_shape)
    n_copies = int(total_size / Na)
    extra = total_size % Na

    if total_size == 0:
        return a[:0]

    if extra != 0:
        n_copies = n_copies+1
        extra = Na-extra

    a = concatenate( (a,)*n_copies)
    if extra > 0:
        a = a[:-extra]

    return reshape(a, new_shape)

np.reshape代码(典型的函数是数组方法的委托):

File:        /usr/lib/python3/dist-packages/numpy/core/fromnumeric.py
def reshape(a, newshape, order='C'):
    try:
        reshape = a.reshape
    except AttributeError:
        return _wrapit(a, 'reshape', newshape, order=order)
    return reshape(newshape, order=order)

比较两个函数的时间 - resize要慢得多。

In [109]: timeit np.resize(a,(2,3,1,1)).shape
10000 loops, best of 3: 41.5 µs per loop

In [110]: timeit np.reshape(a,(2,3,1,1)).shape
100000 loops, best of 3: 2.79 µs per loop

就地resize很快:

In [124]: %%timeit a1=a.copy()
a1.resize((2,3,1,1))
a1.shape
   .....: 
1000000 loops, best of 3: 799 ns per loop