导出模型中的tensorflow批量归一化

时间:2018-06-22 12:24:21

标签: tensorflow

我正在尝试在第三方创建的张量流模型中提取模型结构和权重。

我对如何导出批处理规范(BN)层感到困惑。因此,我正在加载模型并以标准方式检查节点及其值:

with tf.Session() as sess:
    with gfile.FastGFile(os.path.expanduser(tf_frozen_model_path), 'rb') as f:
        gd = tf.GraphDef()
        gd.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(gd, name='')

        nodes = [n for n in gd.node]

        wts = [n for n in nodes if n.op == 'Const']  # model frozen
        for n in wts:
            tensor = n.attr['value'].tensor
            print("Name, shape of the node: ", n.name, tensor.tensor_shape)
            print("Value - ")
            print(tensor_util.MakeNdarray(n.attr['value'].tensor))

现在,对于BN层,大多数参数都是有意义的。因此,典型的BN层如下所示:

// This is the beta parameter
Name, shape of the node:  conv4_3/1/conv4_3/1/bn/beta dim {
  size: 128
}
...values follow

// Moving average
Name, shape of the node:  conv4_3/1/conv4_3/1/bn/moving_mean dim {
  size: 128
}
... values follow

// Moving variance
Name, shape of the node:  conv4_3/1/conv4_3/1/bn/moving_variance dim {
  size: 128
}
... values follow

Name, shape of the node:  conv4_3/1/conv4_3/1/bn/Const dim {
  size: 128
}

Value - 
[1. 1. 1. ...] // Value of all ones..

我对这个Const节点是什么以及为什么要创建它感到困惑?它与批量归一化的scale参数有关吗?

0 个答案:

没有答案