在TensorFlow中,我能做些什么来找出网络中学习参数的数量?
答案 0 :(得分:6)
我没有注意到任何功能,但您仍然可以使用tf.trainable_variables():
上的for循环计算自己
total_parameters = 0
for variable in tf.trainable_variables():
variable_parameters = 1
for dim in variable.get_shape():
variable_parameters *= dim.value
total_parameters += variable_parameters
print("Total number of trainable parameters: %d" % total_parameters)
答案 1 :(得分:2)
您可以使用一个简单的单线纸执行此操作:
#!/bin/bash
awk ' BEGIN {flag=1}
{
if ($0 ~ /Date:/) { flag=0;
}
if ($0 ~ /From/) { flag=1;
}
if (flag==0) {print}
}
'
如果您需要更多细节,请使用以下帮助器函数查看所有可训练的参数:
np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
它会向您显示以下信息:
def show_params():
total = 0
for v in tf.trainable_variables():
dims = v.get_shape().as_list()
num = int(np.prod(dims))
total += num
print(' %s \t\t Num: %d \t\t Shape %s ' % (v.name, num, dims))
print('\nTotal number of params: %d' % total)