如何将数组复制/重复N次到新数组?

时间:2018-06-14 20:29:49

标签: python arrays numpy

我有:

test = np.random.randn(40,40,3)

我想做:

result = Repeat(test, 10)

因此result包含数组test重复10次,形状为:

(10, 40, 40, 3)

因此,创建一个带有新轴的张量,以容纳10份test。我也想尽可能高效地做到这一点。我怎么能和Numpy一起做这个?

谢谢!

3 个答案:

答案 0 :(得分:7)

您可以将np.repeat方法与np.newaxis一起使用:

import numpy as np

test = np.random.randn(40,40,3)
result = np.repeat(test[np.newaxis,...], 10, axis=0)
print(result.shape)
>> (10, 40, 40, 3)

答案 1 :(得分:3)

假设您希望将值复制10次,则只需stack 10个数组:

def repeat(arr, count):
    return np.stack([arr for _ in range(count)], axis=0)

axis=0实际上是默认值,所以这里并不是真的有必要,但我认为你在前面添加新轴更清楚了。

实际上,这与stack的示例完全相同:

>>> arrays = [np.random.randn(3, 4) for _ in range(10)]
>>> np.stack(arrays, axis=0).shape
(10, 3, 4)

乍一看,您可能认为repeattile会更合适。

但是repeat是关于重复现有的轴(或展平数组),因此您需要在{之前或之后} reshape。 (这同样有效,但我认为不那么简单。)

并且tile(假设您使用类似数组的reps - 标量reps,基本上repeat)是关于在所有方向上填写多维规格,比这个简单案例要复杂得多。

所有这些选项都同样有效。他们都将数据复制10次,这是昂贵的部分;任何内部处理,构建微小的中间对象等的成本都是无关紧要的。使速度更快的唯一方法是避免复制。您可能不想这样做。

但如果你这样做...要在10份副本之间共享行存储,你可能需要broadcast_to

def repeat(arr, count):
    return np.broadcast_to(arr, (count,)+arr.shape)

请注意broadcast_to实际上保证它不会复制,只是它返回某种只读视图,其中“广播数组的多个元素可能引用一个单个内存位置“。在实践中,它将避免复制。如果你真的需要保证由于某种原因(或者如果你想要一个可写的视图 - 这通常是一个可怕的想法,但也许你有充分的理由......),你必须下降到{{3} }:

def repeat(arr, count):
    shape = (count,) + arr.shape
    strides = (0,) + arr.strides
    return np.lib.stride_tricks.as_strided(
        arr, shape=shape, strides=strides, writeable=False)

请注意,as_strided的一半文档警告您可能不应该使用它,另一半警告您绝对不应该将它用于可写视图,所以...确保这是什么你想做之前。

答案 2 :(得分:0)

在创建正确副本的众多方法中,预分配+广播似乎最快。

import numpy as np

def f_pp_0():
    out = np.empty((10, *a.shape), a.dtype)
    out[...] = a
    return out

def f_pp_1():
    out = np.empty((10, *a.shape), a.dtype)
    np.copyto(out, a)
    return out

def f_oddn():
    return np.repeat(a[np.newaxis,...], 10, axis=0)

def f_abar():
    return np.stack([a for _ in range(10)], axis=0)

def f_arry():
    return np.array(10*[a])

from timeit import timeit

a = np.random.random((40, 40, 3))

for f in list(locals().values()):
    if callable(f) and f.__name__.startswith('f_'):
        print(f.__name__, timeit(f, number=100000)/100, 'ms')

示例运行:

f_pp_0 0.019641224660445003 ms
f_pp_1 0.019557840081397444 ms
f_oddn 0.01983011547010392 ms
f_abar 0.03257150553865358 ms
f_arry 0.02305851033888757 ms

但是差异很小,例如repeat几乎不会慢。