如何在TensorFlow中访问protos中的值?

时间:2017-05-23 02:44:14

标签: tensorflow

我从tutorial看到我们可以做到这一点:

for node in tf.get_default_graph().as_graph_def().node: print node

在任意网络上完成后,我们会获得许多键值对。例如:

name: "conv2d_2/convolution"
op: "Conv2D"
input: "max_pooling2d/MaxPool"
input: "conv2d_1/kernel/read"
device: "/device:GPU:0"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "data_format"
  value {
    s: "NHWC"
  }
}
attr {
  key: "padding"
  value {
    s: "SAME"
  }
}
attr {
  key: "strides"
  value {
    list {
      i: 1
      i: 1
      i: 1
      i: 1
    }
  }
}
attr {
  key: "use_cudnn_on_gpu"
  value {
    b: true
  }
}

如何访问所有这些值并将它们放在Python列表中?具体来说,我们如何才能获得“步伐”。属性并将所有1转换为[1,1,1,1]?

1 个答案:

答案 0 :(得分:3)

TLDR:以下代码是您可能想要使用的代码:

for n in tf.get_default_graph().as_graph_def().node: if 'strides' in n.attr.keys(): print n.name, [int(a) for a in n.attr['strides'].list.i] if 'shape' in n.attr.keys(): print n.name, [int(a.size) for a in n.attr['shape'].shape.dim]

这样做的诀窍是了解protobufs是什么。我们来看看上面提到的tutorial

首先,有一个声明:

  

for node in graph_def.node

     

每个节点都是NodeDef对象,定义于   tensorflow /核心/框架/ node_def.proto。这些都是根本   TensorFlow图的构建块,每个图定义一个   操作及其输入连接。这是一个成员   NodeDef及其含义。

请注意node_def.proto中的以下内容:

  • 导入attr_value.proto。
  • 有name,op,input,device,attr等属性。具体来说,输入前面有repeated个术语。我们现在可以忽略这一点。

这与Python类完全相同,因此我们可以调用node.name,node.op,node.input,node.device,node.attr等。

我们现在想要访问的是node.attr中的内容。如果我们再次参考教程,它会指定:

  

这是一个包含节点所有属性的键/值存储。这些   是节点的永久属性,不会改变的东西   运行时,例如卷积过滤器的大小,或者值的   不断的操作。因为可以有这么多不同类型的   属性值,从字符串到整数,到张量值数组,   有一个单独的protobuf文件定义了数据结构   在tensorflow / core / framework / attr_value.proto中保存它们。

     

每个属性都有唯一的名称字符串和预期的属性   在定义操作时列出。如果属性不是   存在于节点中,但它在操作中列出了默认值   定义,在创建图表时使用默认值。

     

您可以通过调用node.name,node.op来访问所有这些成员,   Python中的等等。存储在GraphDef中的节点列表是完整的   模型架构的定义。

由于这是一个键值存储,我们可以调用n.attr.keys()来查看此属性具有的键列表。如果有这样的密钥,我们可以进一步调用n.attr['strides']来访问大步。当我们尝试打印时,我们会得到以下结果:

list {
  i: 1
  i: 2
  i: 2
  i: 1
}

这就是它开始变得混乱的地方,因为我们可能会尝试list(n.attr['strides'])或类似的东西。如果我们查看attr_value.proto,我们就可以了解正在发生的事情。我们看到它是oneof value,在这种情况下它是ListValue list,因此我们可以调用n.attr['strides'].list。如果我们打印出来,我们会得到以下结果:

i: 1
i: 1
i: 1
i: 1

我们接下来可能会尝试执行此操作:[a for a in n.attr['strides'].list][a.i for a in n.attr['strides'].list]。然而,没有任何作用。这是repeated是理解的重要术语。它基本上意味着有一个int64列表,你必须使用i属性访问它。然后做[int(a) for a in n.attr['strides'].list.i]给我们想要的东西,我们可以使用的Python列表:

[1, 1, 1, 1]