在Tensorflow中C ++相当于python:tf.Graph.get_tensor_by_name()?

时间:2016-09-07 20:41:16

标签: c++ tensorflow

Tensorflow中的python:tf.Graph.get_tensor_by_name(name)的C ++等价物是什么?谢谢!

以下是我尝试运行的代码,但我得到一个空的output

Status status = NewSession(SessionOptions(), &session); // create new session
ReadBinaryProto(tensorflow::Env::Default(), model, &graph_def); // read Graph
session->Create(graph_def); // add Graph to Tensorflow session 
std::vector<tensorflow::Tensor> output; // create Tensor to store output
std::vector<string> vNames; // vector of names for required graph nodes
vNames.push_back("some_name"); // I checked names and they are presented in loaded Graph

session->Run({}, vNames, {}, &output); // ??? As a result I have empty output

3 个答案:

答案 0 :(得分:3)

有一种方法可以直接从graph_def获取神经节点。 如果你只想要节点的形状\类型:“some_name”:

void readPB(GraphDef & graph_def)
{

    int i;
    for (i = 0; i < graph_def.node_size(); i++)
    {
        if (graph_def.node(i).name() == "inputx")
        {
            graph_def.node(i).PrintDebugString();
        }
    }
}

结果:

name: "inputx"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: -1
      }
      dim {
        size: 5120
      }
    }
  }
}

尝试节点的成员功能并获取信息。

答案 1 :(得分:2)

your comment开始,听起来您正在使用C ++ tensorflow::Session API,它将图表表示为GraphDef协议缓冲区。此API中有无等效tf.Graph.get_tensor_by_name()

不是将已键入的tf.Tensor对象传递给Session::Run(),而是传递string张量的<NODE NAME>:<N>名称,其格式为<NODE NAME>,其中NodeDef.name匹配一个GraphDef中的<N>值,session->Run()是一个整数,对应于您要获取的节点的输出索引。

你问题中的代码看起来大致正确,但有两件事我建议:

  1. tensorflow::Status调用返回output值。如果在调用返回后"some_name"为空,则几乎可以肯定该调用返回了错误状态,并带有解释问题的消息。

  2. 您将"some_name:0"作为张量的名称传递给fetch,但它是节点的名称,而不是张量。此API可能需要您明确指定输出索引:尝试将其替换为scale_fill_identity

答案 2 :(得分:1)

如果有人感兴趣,这是使用tensorflow C ++ API从graph_def中提取任意传感器的形状的方法

vector<int64_t> get_shape_of_tensor(tensorflow::GraphDef graph_def, std::string name_tensor)
{
    vector<int64_t> tensor_shape;
    for (int i=0; i < graph_def.node_size(); ++i) {
        if (graph_def.node(i).name() == name_tensor) {
            auto node = graph_def.node(i);
            auto attr_map = node.attr();
            for (auto it=attr_map.begin(); it != attr_map.end(); it++) {
                auto key = it->first;
                auto value = it->second;
                if (value.has_shape()) {
                    auto shape = value.shape();
                    for (int i=0; i<shape.dim_size(); ++i) {
                        auto dim = shape.dim(i);
                        auto dim_size = dim.size();
                        tensor_shape.push_back(dim_size);
                    }
                }
            }
        }
    }
    return tensor_shape
}