将维度添加到numpy矩阵

时间:2018-04-04 15:29:40

标签: python numpy

为什么下面的代码没有向矩阵添加第三个维度?我应该如何更改代码以获得我需要的东西?

imm = np.zeros(shape=(10,10))
dims = imm.shape
if len(dims) < 3:
    imm.reshape((dims[0], dims[1], 1))
imm.shape # I want this to pring (10,10,1)

5 个答案:

答案 0 :(得分:2)

向numpy数组添加维度的最佳方法是

imm = imm[..., np.newaxis]
不用担心,这不是做任何昂贵的复制或任何事情,底层内存仍然存在*。所有这一切都是ndim增加1。

我认为这个解决方案更清晰,因为您不必担心抓住shape并将其放入reshape。它还具有以下优点:您可以快速将newaxis置于任何维度。假设您希望维度0或1成为新维度

imm = imm[np.newaxis, ...]  # new dimension is axis 0
imm = imm[:, np.newaxis, :] # new dimension is axis 1
imm = imm[:, :, np.newaxis] # new dimension is axis 2 (same as above)

*你可以通过

向自己证明这一点
x = np.array([1,2,3])
y = x[..., np.newaxis]
x *= 10
print(y)
# 
# [[10]
#  [20]
#  [30]]

答案 1 :(得分:1)

你可以这样做:

item.data

和 重塑不起作用。您的代码无效,因为您没有将reshape返回的值分配给任何变量。

答案 2 :(得分:1)

imm.reshape不执行就地修改。相反,它返回修改后的数组。

因此,执行imm = imm.reshape((dims[0], dims[1], 1)),您将获得所需的输出。请考虑阅读documentation了解更多详情

答案 3 :(得分:1)

如果您想要修改imm,最简单的方法是:

import numpy as np
imm = np.zeros(shape=(10,10))
dims = imm.shape
if len(dims) < 3:
    imm.shape = imm.shape + (1,)
print(imm.shape) 
# (10,10,1)

您可以使用

imm.shape = imm.shape + (1,)

或者简单地说,

imm.shape += 1,

答案 4 :(得分:1)

有许多方法可以更改数组的尺寸,只要新尺寸的尺寸与旧尺寸的产品保持一致即可。

在所有情况下,只有一个,整形操作返回一个新的数组对象,但在内存中使用相同的底层数据。这意味着即使对于巨大的阵列,重塑也非常便宜。您只需将返回的引用分配回原始数组。

重新整形数组对象的唯一方法是分配给它的shape属性:

arr.shape += (1,)

添加维度的所有其他方法都返回一个新的数组对象。在大多数案例中,他们不会复制数据,除非您使用索引方案做一些奇怪的事情:

  1. 使用np.reshape

    arr = np.reshape(arr, arr.shape + (1,))
    
  2. 使用arr.reshape。这与np.reshape完全相同:

    arr = arr.reshape(arr.shape + (1,))
    
  3. 使用np.newaxis(a.k.a。None)对数组编制索引:

    arr = arr[..., np.newaxis]
    

    OR

    arr = arr[..., None]
    

    这会创建一个新的数组对象,在您放置newaxis对象的任何位置插入一个大小为1的新轴,但不会触及基础数据。

  4. 使用np.expand_dims。这将通过第二个参数在任何地方插入一个大小为1的新维度:

    arr = np.expand_dims(arr, -1)
    
  5. 对于您的特定情况,您可以使用np.atleast_3d

    arr = np.atleast_3d(arr)
    

    当给定1D输入时,此函数表现奇怪:它将大小为1的维度放在元组的开头和结尾,而不是一侧的所有维度。 2D情况完全符合您的要求。

  6. 使用np.array显式创建具有预期维数的新数组对象:

    arr = np.array(arr, copy=False, ndmin=3)
    
    必须明确说明

    copy=False以避免复制数据。第三个轴预先到形状,而不是附加,所以你可能需要做类似的事情

    arr = np.moveaxis(arr, 0, -1)
    

    到目前为止,这是满足您特定需求的最不理想选择。

  7. <强>更新

    如果我们将您的原始问题等同于图钉,这就是解决方案的核大锤:

    arr = np.lib.stride_tricks.as_strided(arr, arr.shape + (1,))
    

    请记住as_strided真的是一把核大锤!