在Tensorflow中,使用c ++获取图形中所有张量的名称

时间:2017-11-20 06:38:51

标签: c++ c++11 tensorflow

我已加载" .pb"使用c ++的模型 我希望打印所有的模型操作。

例如:下面的.pb文件中的图层:

node {
  name: "add"
  op: "Add"
  input: "MatMul"
  input: "bias/read"
  attr {
    key: "T"
    value {
    type: DT_FLOAT
    }
  }
}
node {
  name: "output_TT"
  op: "Softmax"
  input: "add"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }  
}

我想获取名称(即add和output_TT)并使用 tensorflow 库使用c ++显示它们。

我没有使用bazel来建造;相反,我用一些自定义来执行inbuild makefile。

1 个答案:

答案 0 :(得分:4)

我按照步骤

获得了输出
int node_count = graph_def.node_size();
for (int i = 0; i < node_count; i++)
{
        auto n = graph_def.node(i);
        cout<<"Names : "<< n.name() <<endl;

}