计算get_weights()输出中非零元素数量的最快方法

时间:2019-06-21 19:09:54

标签: numpy keras neural-network

我想计算神经网络权重中非零值的数量。

我尝试了以下代码,但获得了ValueError。这可能是由于每个数组具有不同形状的原因。

h = model.get_weights()  # return a list of numpy arrays
merged_h = []
for l in h:
    merged_h += l
nzcounts = np.count_nonzero(merged_h)

ValueError: operands could not be broadcast together with shapes (0,) (3,3,3,32) 

我想知道是否还有其他方法可以计算get_weights()的输出中非零元素的数量?谢谢!

1 个答案:

答案 0 :(得分:1)

本质上,问题在于model.get_weights()返回数组列表。我认为最简单的方法是将np.count_nonzero()分别应用于这些数组中的每个数组,然后对结果求和。

np.sum([np.count_nonzero(x) for x in model.get_weights()])