从字典python中有效地dstack数组

时间:2012-07-08 22:13:29

标签: python numpy dictionary

我有一个按日期键入的字典,并填充了属性为numpy.array的类。我想使用np.dstack从字典中的所有数组创建一个大数组。我目前的代码是这样的:

import numpy as np
#PARTS is my dictionary
#the .partposit is the attribute that is an array of shape (50000, 12)
ks = sorted(PARTS.keys())
p1 = PARTS[ks[0]].partposit
for k in ks[1:]:
    p1 = np.dstack((p1, PARTS[k].partposit))

我的结果如我所料:

In [67]: p1.shape
Out[67]: (50000, 12, 163)

然而,它很慢。有没有更有效的方法来做到这一点?

2 个答案:

答案 0 :(得分:3)

你可以试试这个:

>>> import numpy as np
>>> class A:
...     def __init__(self, values):
...         self.partposit = values
... 
>>> PARTS = dict((index, A(np.zeros((50000, 12)))) for index in xrange(163))
>>> p1 = np.dstack((PARTS[k].partposit for k in sorted(PARTS.keys())))
>>> p1.shape
(50000, 12, 163)
>>> 

将它堆叠在我的机器上花了几秒钟。

>>> import timeit
>>> timeit.Timer('p1 = np.dstack((PARTS[k].partposit for k in sorted(PARTS.keys())))', "from __main__ import np, PARTS").timeit(number = 1)
2.1245520114898682

numpy.dstack接受一系列数组并将它们堆叠在一起,如果我们只是给它列表而不是自己连续堆叠它们会更快。

  

numpy.dstack(TUP)
   

按顺序深度(沿第三轴)堆叠数组。       采用一系列数组并沿第三轴堆叠它们以形成单个数组。

http://docs.scipy.org/doc/numpy/reference/generated/numpy.dstack.html

我也很好奇看你的方法会有多长:

>>> import timeit
>>> setup = """
... import numpy as np
... #PARTS is my dictionary
... #the .partposit is the attribute that is an array of shape (50000, 12)
... 
... class A:
...     def __init__(self, values):
...         self.partposit = values
... 
... PARTS = dict((index, A(np.zeros((50000, 12)))) for index in xrange(163))
... ks = sorted(PARTS.keys())
... """
>>> stack = """
... p1 = PARTS[ks[0]].partposit
... for k in ks[1:]:
...     p1 = np.dstack((p1, PARTS[k].partposit))
... """
>>> timeit.Timer(stack, setup).timeit(number = 1)
67.69684886932373

哎哟!

>>> numpy.__version__
'1.6.1'

$ python --version
Python 2.6.1

我希望这会有所帮助。

答案 1 :(得分:0)

此行创建一个新列表(浅层副本),这是不必要的开销:

for k in ks[1:]:

更有效的方法是:

itks =iter(ks)
next(itks)
for k in itks:

此外,您可以使用以下方法消除重复查找:

entries = iter(sorted(((k, v.partposit) for k,v in PARTS.iteritems()), key=lambda(k,v):k))
p1 = next(entries)[1]
for k,v in entries: 
    p1 = np.dstack((p1, v))

这会使事情略微加快,因为它消除了dict中的复制和重复查找(虽然它是恒定时间,但不是免费的)。