有Keras函数可以计算单位总数吗?

时间:2019-11-19 09:12:35

标签: keras

Keras具有 count_param() Python函数,用于计算人工神经网络(ANN)模型的可训练参数的总数。

model.count_params()

以同样的方式,是否存在Keras函数来计算ANN模型的单位总数?

1 个答案:

答案 0 :(得分:1)

看起来似乎没有简单的方法可以解决此问题。例如,输入层将返回一个元组列表,其中(大多数?)其他层仅返回一个元组。但是以下功能在大多数情况下应该起作用。

很明显,此函数接受一个模型并返回两个输出。

  • 输出单位总数
  • 将每个图层的输出单位作为列表

让我知道它是否不适用于任何特定情况(因为我尚未对此进行详尽的测试)

from functools import reduce
from itertools import chain
import operator 
def count_units(model):
  tot_out = 0
  out_list = []
  for lyr in model.layers:
    if lyr.trainable:
      # This is to tackle any layers that have the output shape as a list of tuples (e.g Input layer)
      if isinstance(lyr.output_shape, list):
        curr_out = reduce(operator.mul, chain(*[s[1:] for s in lyr.output_shape]), 1)
      # This is to tackle other layers like Dense and Conv2D
      elif isinstance(lyr.output_shape, tuple):
        curr_out = reduce(operator.mul, lyr.output_shape[1:], 1)
      else:
        raise TypeError
      tot_out += curr_out
      out_list.append(curr_out)
  return tot_out, out_list

print(count_units(model))