我正在尝试在第三方创建的张量流模型中提取模型结构和权重。
我对如何导出批处理规范(BN)层感到困惑。因此,我正在加载模型并以标准方式检查节点及其值:
with tf.Session() as sess:
with gfile.FastGFile(os.path.expanduser(tf_frozen_model_path), 'rb') as f:
gd = tf.GraphDef()
gd.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(gd, name='')
nodes = [n for n in gd.node]
wts = [n for n in nodes if n.op == 'Const'] # model frozen
for n in wts:
tensor = n.attr['value'].tensor
print("Name, shape of the node: ", n.name, tensor.tensor_shape)
print("Value - ")
print(tensor_util.MakeNdarray(n.attr['value'].tensor))
现在,对于BN层,大多数参数都是有意义的。因此,典型的BN层如下所示:
// This is the beta parameter
Name, shape of the node: conv4_3/1/conv4_3/1/bn/beta dim {
size: 128
}
...values follow
// Moving average
Name, shape of the node: conv4_3/1/conv4_3/1/bn/moving_mean dim {
size: 128
}
... values follow
// Moving variance
Name, shape of the node: conv4_3/1/conv4_3/1/bn/moving_variance dim {
size: 128
}
... values follow
Name, shape of the node: conv4_3/1/conv4_3/1/bn/Const dim {
size: 128
}
Value -
[1. 1. 1. ...] // Value of all ones..
我对这个Const
节点是什么以及为什么要创建它感到困惑?它与批量归一化的scale
参数有关吗?