CNN可学习参数的数量--Python / TensorFlow

时间:2017-11-15 14:35:02

标签: python tensorflow conv-neural-network

在TensorFlow中,我能做些什么来找出网络中学习参数的数量?

2 个答案:

答案 0 :(得分:6)

我没有注意到任何功能,但您仍然可以使用tf.trainable_variables():上的for循环计算自己

total_parameters = 0
for variable in tf.trainable_variables():
    variable_parameters = 1
    for dim in variable.get_shape():
        variable_parameters *= dim.value
    total_parameters += variable_parameters

print("Total number of trainable parameters: %d" % total_parameters)

答案 1 :(得分:2)

您可以使用一个简单的单线纸执行此操作:

#!/bin/bash
awk ' BEGIN {flag=1}

{
if ($0 ~ /Date:/) { flag=0;
}
if ($0 ~ /From/) {  flag=1;
}
if (flag==0) {print}
}

'

如果您需要更多细节,请使用以下帮助器函数查看所有可训练的参数:

np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])

它会向您显示以下信息:

def show_params():
  total = 0
  for v in tf.trainable_variables():
    dims = v.get_shape().as_list()
    num  = int(np.prod(dims))
    total += num
    print('  %s \t\t Num: %d \t\t Shape %s ' % (v.name, num, dims))
  print('\nTotal number of params: %d' % total)