如何使用 C API 遍历 TensorFlow 图?

时间:2018-04-17 00:54:18

标签: tensorflow machine-learning

下面的小程序会创建一个简单的 TensorFlow 图。我需要遍历这个图,并打印各个节点的信息。

假设每个图都有一个根节点(或某个特殊的起始节点),这样理解对吗?我认为这个图有 3 个节点,并且听说图中的边就是张量。

#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include"tensorflow/c/c_api.h"

/* Graph shared by the whole program: main() assigns it from TF_NewGraph()
 * and BuildSuccessGraph() adds its operations to it. */
TF_Graph* g;
/* Status object shared by the whole program: main() assigns it from
 * TF_NewStatus() and BuildSuccessGraph() passes it to every helper. */
TF_Status* s;

/*
 * Leaves the enclosing function when the status `x` is not TF_OK: the status
 * message is printed to stdout followed by a newline and a null pointer is
 * returned. Only usable inside functions whose return type is a pointer.
 *
 * The macro now really inspects its argument (it used to ignore `x` and read
 * whatever variable named `s` happened to be in scope), and it is wrapped in
 * do { } while (0) so that `CHECK_OK(st);` behaves as one statement, even in
 * an unbraced if/else. `x` is evaluated twice on the failure path, so pass a
 * plain variable.
 */
#define CHECK_OK(x)                    \
  do {                                 \
    if (TF_OK != TF_GetCode(x)) {      \
      printf("%s\n", TF_Message(x));   \
      return (void*)0;                 \
    }                                  \
  } while (0)

/*
 * Allocates a TF_FLOAT tensor with dimensions {2, 2} and fills it with the
 * four floats read from `values`.
 *
 * The freshly allocated tensor is returned as-is; nothing in this function
 * releases it.
 */
TF_Tensor* FloatTensor2x2(const float* values) {
  enum { kRank = 2, kSide = 2, kCount = kSide * kSide };
  const size_t byte_len = sizeof(float) * kCount;
  const int64_t shape[kRank] = {kSide, kSide};

  TF_Tensor* result = TF_AllocateTensor(TF_FLOAT, shape, kRank, byte_len);
  memcpy(TF_TensorData(result), values, byte_len);
  return result;
}

/*
 * Adds a "Const" operation called `name` to `graph`. Its "value" attribute is
 * a 2x2 float tensor built from the four floats in `values`, and its "dtype"
 * attribute is TF_FLOAT.
 *
 * Returns the finished operation. When `s` reports anything other than TF_OK
 * the status message is printed and a null pointer is returned instead.
 */
TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s, const float* values, const char* name) {
  TF_Tensor* tensor = FloatTensor2x2(values);
  TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
  TF_SetAttrTensor(desc, "value", tensor, s);
  /* The description does not take ownership of the tensor handle (the
   * TensorFlow C API test helpers delete it right after this call), so it is
   * released here on every path instead of being leaked. */
  TF_DeleteTensor(tensor);
  /* Report this failure the same way as the one after TF_FinishOperation. */
  CHECK_OK(s);
  TF_SetAttrType(desc, "dtype", TF_FLOAT);
  TF_Operation* op = TF_FinishOperation(desc, s);
  CHECK_OK(s);
  return op;
}

/*
 * Adds a "MatMul" operation called `name` to `graph`, taking output 0 of `l`
 * and output 0 of `r` as its two inputs (in that order).
 *
 * A non-zero `transpose_a` / `transpose_b` sets the attribute of the same
 * name to true; when the flag is zero the attribute is not written at all.
 *
 * Returns the finished operation, or a null pointer (after printing the
 * status message) when `s` is not TF_OK afterwards.
 */
TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l, TF_Operation* r, const char* name,
                     char transpose_a, char transpose_b) {
  const TF_Output lhs = {l, 0};
  const TF_Output rhs = {r, 0};
  TF_OperationDescription* builder = TF_NewOperation(graph, "MatMul", name);

  if (transpose_a != 0) {
    TF_SetAttrBool(builder, "transpose_a", 1);
  }
  if (transpose_b != 0) {
    TF_SetAttrBool(builder, "transpose_b", 1);
  }

  TF_AddInput(builder, lhs);
  TF_AddInput(builder, rhs);

  TF_Operation* product = TF_FinishOperation(builder, s);
  CHECK_OK(s);
  return product;
}

/*
 * Populates the global graph `g` with two 2x2 float constants and a MatMul
 * that multiplies them:
 *
 *            |
 *           z|
 *            |
 *          MatMul
 *         /      \.
 *        ^        ^
 *        |        |
 *    x Const_0  y Const_1
 *
 * On success, inputs[0] and inputs[1] receive output 0 of Const_0 and
 * Const_1, outputs[0] receives output 0 of MatMul, and `g` is returned.
 * `inputs` must therefore have room for two entries and `outputs` for one.
 *
 * If any operation cannot be created a null pointer is returned at once and
 * neither array is written; the global status `s` is left as the failing
 * call set it.
 */
TF_Graph* BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
  const float const0_val[] = {1.0f, 2.0f, 3.0f, 4.0f};
  const float const1_val[] = {1.0f, 0.0f, 0.0f, 1.0f};

  /* Stop at the first failure: a null operation must never be handed to
   * MatMul or stored in a TF_Output. */
  TF_Operation* const0 = FloatConst2x2(g, s, const0_val, "Const_0");
  if (const0 == NULL) {
    return NULL;
  }
  TF_Operation* const1 = FloatConst2x2(g, s, const1_val, "Const_1");
  if (const1 == NULL) {
    return NULL;
  }
  TF_Operation* matmul = MatMul(g, s, const0, const1, "MatMul", 0, 0);
  if (matmul == NULL) {
    return NULL;
  }

  inputs[0] = (TF_Output){const0, 0};
  inputs[1] = (TF_Output){const1, 0};
  outputs[0] = (TF_Output){matmul, 0};
  return g;
}

/*
 * Creates the global graph and status, builds the example graph, and prints
 * "OK" when construction succeeded. The graph and the status are released
 * before returning.
 *
 * Returns EXIT_SUCCESS when the graph was built, EXIT_FAILURE otherwise.
 */
int main(int argc, char const *argv[]) {
  (void)argc;  /* command-line arguments are not used */
  (void)argv;

  g = TF_NewGraph();
  s = TF_NewStatus();

  TF_Output inputs[2],outputs[1];
  int exit_code = EXIT_SUCCESS;

  if (BuildSuccessGraph(inputs,outputs) == NULL) {
    /* Do not claim success when the graph could not be built. */
    fprintf(stderr, "graph construction failed: %s\n", TF_Message(s));
    exit_code = EXIT_FAILURE;
  } else {
    /* HERE traverse g -- maybe with {inputs,outputs} -- to print the graph */

    fprintf(stdout, "OK\n");
  }

  /* Release what TF_NewGraph / TF_NewStatus handed out. */
  TF_DeleteGraph(g);
  TF_DeleteStatus(s);
  return exit_code;
}

如果有人能告诉我们应该使用哪些函数来获取图的相关信息,我们将不胜感激。

1 个答案:

答案 0 :(得分:2)

来自c_api.h:

// Iterate through the operations of a graph.  To use:
// size_t pos = 0;
// TF_Operation* oper;
// while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//   DoSomethingWithOperation(oper);
// }
TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph,
                                                      size_t* pos);

注意,这个函数只会逐个返回图中的操作(Operation),并没有定义从一个节点导航到下一个节点的方法——边的关系存储在节点自身之中(以指针的形式)。