Sum数组数组中每个数组的所有元素

时间:2017-08-20 15:11:28

标签: python numpy

我有这本词典:

import numpy as np
dict={'W1': np.array([[ 1.62434536, -0.61175641, -0.52817175], 
                     [-1.07296862,  0.86540763, -2.3015387 ]]), 
     'b1': np.array([[ 1.74481176], 
                     [-0.7612069 ]]), 
     'W2': np.array([[ 0.3190391 , -0.24937038], 
                     [ 1.46210794, -2.06014071], 
                     [-0.3224172 , -0.38405435]]), 
     'b2': np.array([[ 1.13376944], 
                     [-1.09989127], 
                     [-0.17242821]]), 
     'W3': np.array([[-0.87785842,  0.04221375,  0.58281521]]), 
     'b3': np.array([[-1.10061918]])}

我需要在平方后对W1,W2,W3的所有元素求和,每次都是三个。

我用它来提取一个带有键W(i)

的列表
l=[v for k, v in dict.items() if 'W' in k]

我怎样才能得到每个数组中平方元素的总和? 当我分别拍摄每个阵列时:

 np.sum(np.square(l[0]) to get 10.4889815722 for l[0]

我不知道如何一次性总结它们

2 个答案:

答案 0 :(得分:1)

您可以使用字典理解简单地提取所有值的总和:

>>> res = {key: np.square(arr).sum() for key, arr in dct.items()}  # you could also use if 'W' in key here too.
>>> res
{'W1': 10.48898156439229,
 'W2': 6.7973615015702658,
 'W3': 1.1120909752613031,
 'b1': 3.6238040224419072,
 'b2': 2.5249254365039309,
 'b3': 1.2113625793838725}

鉴于字典是无序的,因为字典可能更好(例如可以通过res['W1']访问),因为否则列表元素将按任意顺序排列(或者您需要在将它们放入之前对其进行排序列表)。

汇总所有W*值:

>>> sum(v for k, v in res.items() if 'W' in k)  # normal sum this time but would also work with np.sum!
18.398434041223858

答案 1 :(得分:0)

我不确定我找到了你,但这就是我认为你在寻找的东西:

>>> import numpy as np
>>> data = [np.array([1.62434536, -0.61175641, -0.52817175]), np.array([-1.07296862, 0.86540763, -2.3015387])]
>>> [np.sum(arr) for arr in data]
[0.48441719999999999, -2.5090996900000002]