组合图形:C ++有一个TensorFlow import_graph_def等价物吗?

时间:2018-03-26 11:22:59

标签: python c++ pointers tensorflow merge

我需要使用自定义输入和输出层扩展导出的模型。我发现这可以轻松完成:

with tf.Graph().as_default() as g1: # actual model
    in1 = tf.placeholder(tf.float32,name="input")
    ou1 = tf.add(in1,2.0,name="output")
with tf.Graph().as_default() as g2: # model for the new output layer
    in2 = tf.placeholder(tf.float32,name="input")
    ou2 = tf.add(in2,2.0,name="output")

gdef_1 = g1.as_graph_def()
gdef_2 = g2.as_graph_def()

with tf.Graph().as_default() as g_combined: #merge together
    x = tf.placeholder(tf.float32, name="actual_input") # the new input layer

    # Import gdef_1, which performs f(x).
    # "input:0" and "output:0" are the names of tensors in gdef_1.
    y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                             return_elements=["output:0"])

    # Import gdef_2, which performs g(y)
    z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                             return_elements=["output:0"])

sess = tf.Session(graph=g_combined)

print "result is: ", sess.run(z, {"actual_input:0":5}) #result is: 9

这很好用。

然而,我不是以任意形状传递数据集,而是需要将指针作为网络输入。问题是,我无法想到这个内部python(定义和传递指针)的任何解决方案,并且当使用C++ Api开发网络时,我无法找到与{ {1}}功能。

这在C ++中是否有不同的名称,或者是否有其他方法可以在C ++中合并两个图形/模型?

感谢您的任何建议

2 个答案:

答案 0 :(得分:2)

它并不像Python那么容易。

您可以使用以下内容加载GraphDef

#include <string>
#include <tensorflow/core/framework/graph.pb.h>
#include <tensorflow/core/platform/env.h>

tensorflow::GraphDef graph;
std::string graphFileName = "...";
auto status = tensorflow::ReadBinaryProto(
    tensorflow::Env::Default(), graphFileName, &graph);
if (!status.ok()) { /* Error... */ }

然后你可以用它来创建一个会话:

#include <tensorflow/core/public/session.h>

tensorflow::Session *newSession;
auto status = tensorflow::NewSession(tensorflow::SessionOptions(), &newSession);
if (!status.ok()) { /* Error... */ }
status = session->Create(graph);
if (!status.ok()) { /* Error... */ }

或者扩展现有图表:

status = session->Extend(graph);
if (!status.ok()) { /* Error... */ }

这样您就可以将多个GraphDef放入同一个图表中。但是,没有额外的工具来提取特定节点,也没有避免名称冲突 - 您必须自己找到节点,并且必须确保GraphDef没有冲突的操作名称。例如,我使用此函数查找名称与给定正则表达式匹配的所有节点,按名称排序:

#include <vector>
#include <regex>
#include <tensorflow/core/framework/node_def.pb.h>

std::vector<const tensorflow::NodeDef *> GetNodes(const tensorflow::GraphDef &graph, const std::regex &regex)
{
    std::vector<const tensorflow::NodeDef *> nodes;
    for (const auto &node : graph.node())
    {
        if (std::regex_match(node.name(), regex))
        {
            nodes.push_back(&node);
        }
    }
    std::sort(nodes.begin(), nodes.end(),
              [](const tensorflow::NodeDef *lhs, const tensorflow::NodeDef *rhs)
              {
                  return lhs->name() < rhs->name();
              });
    return nodes;
}

答案 1 :(得分:0)

这可以在C ++中通过直接操作要组合的两个图的GraphDef中的NodeDef来实现。基本算法是定义两个GraphDef,使用占位符作为第二个GraphDef的输入,并将它们重定向到第一个GraphDef的输出。这类似于通过将第二个电路的输入连接到第一个电路的输出来串联两个电路。

首先,定义示例GraphDef,以及用于观察GraphDef内部的实用程序。重要的是要注意,两个GraphDef中的所有节点都必须具有唯一的名称。

<mat-form-field [@error]="isError" 
                (@error.done)="isError = false"
                class="al-subscribe-form-field">
...
</mat-form>

现在,将创建两个GraphDef,并将第二个GraphDef的输入连接到第一个GraphDef的输出。这是通过在节点上进行迭代并标识输入为占位符的第一个操作节点并将这些输入重定向到第一个GraphDef的输出来完成的。然后将该节点以及所有后续节点添加到第一个GraphDef。结果是第一个GraphDef附加第二个GraphDef。

Status Panel::SampleFirst(GraphDef *graph_def) 
{
    Scope root = Scope::NewRootScope();
    Placeholder p1(root.WithOpName("p1"), DT_INT32);
    Placeholder p2(root.WithOpName("p2"), DT_INT32);
    Add add(root.WithOpName("add"), p1, p2);
    return root.ToGraphDef(graph_def);
}

Status Panel::SampleSecond(GraphDef *graph_def)
{
    Scope root = Scope::NewRootScope();
    Placeholder q1(root.WithOpName("q1"), DT_INT32);
    Placeholder q2(root.WithOpName("q2"), DT_INT32);
    Add sum(root.WithOpName("sum"), q1, q2);
    Multiply multiply(root.WithOpName("multiply"), sum, 4);
    return root.ToGraphDef(graph_def);
}

void Panel::ShowGraphDef(GraphDef &graph_def)
{
    for (int i = 0; i < graph_def.node_size(); i++) {
        NodeDef node_def = graph_def.node(i);
        cout << "NodeDef name is " << node_def.name() << endl;
        cout << "NodeDef op is " << node_def.op() << endl;
        for (const string& input : node_def.input()) {
            cout << "\t input: " << input << endl;
        }
    }
}

此特定图将采用两个输入2和3,并将它们加在一起。然后,将那个(5)的总和再次添加到第一个输入(2),然后乘以4,得到结果28。((2 + 3)+ 2)* 4 = 28。