下面这个小程序会创建一个简单的 TensorFlow 图。我需要遍历这个图,并打印其中各个节点的信息。
可以假设每个图都有一个根节点(或某个特殊的起始节点)吗?我认为这个图有 3 个节点;另外我听说图中的边就是张量。
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include"tensorflow/c/c_api.h"
/* Graph and status shared by BuildSuccessGraph() and main(). */
TF_Graph* g;
TF_Status* s;

/*
 * CHECK_OK(status): if `status` is not TF_OK, print its message to stderr and
 * make the enclosing function return NULL. Only usable inside functions that
 * return a pointer.
 *
 * The argument is actually used (the previous version ignored it and read
 * whatever `s` was in scope), it is evaluated once, and the do/while(0)
 * wrapper makes the macro behave as a single statement, so it is safe inside
 * an unbraced if/else.
 */
#define CHECK_OK(status)                                          \
    do {                                                          \
        TF_Status* check_ok_status = (status);                    \
        if (TF_GetCode(check_ok_status) != TF_OK) {               \
            fprintf(stderr, "%s\n", TF_Message(check_ok_status)); \
            return NULL;                                          \
        }                                                         \
    } while (0)
/*
 * Allocate a 2x2 TF_FLOAT tensor and copy the first four floats of `values`
 * into its data buffer. The returned tensor is newly allocated; the caller is
 * responsible for releasing it.
 */
TF_Tensor* FloatTensor2x2(const float* values) {
    enum { kRank = 2, kElemCount = 4 };
    const int64_t shape[kRank] = {2, 2};
    const size_t byte_len = kElemCount * sizeof(float);

    TF_Tensor* result = TF_AllocateTensor(TF_FLOAT, shape, kRank, byte_len);
    memcpy(TF_TensorData(result), values, byte_len);
    return result;
}
/*
 * Add a "Const" operation named `name` to `graph` whose "value" attribute is a
 * 2x2 float tensor built from values[0..3] and whose "dtype" is TF_FLOAT.
 *
 * Returns the new operation. On failure (setting the tensor attribute or
 * finishing the operation), the error is left in `s`, its message is printed
 * via CHECK_OK, and NULL is returned.
 */
TF_Operation* FloatConst2x2(TF_Graph* graph, TF_Status* s, const float* values, const char* name) {
    TF_Tensor* tensor = FloatTensor2x2(values);
    TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);

    TF_SetAttrTensor(desc, "value", tensor, s);
    /* TF_SetAttrTensor does not take ownership of `tensor`; release it on
     * every path (it previously leaked). */
    TF_DeleteTensor(tensor);
    /* Report the failure instead of returning silently.
     * NOTE(review): on this path `desc` is never passed to
     * TF_FinishOperation, so it is not released -- confirm whether the C API
     * offers a way to abandon a description. */
    CHECK_OK(s);

    TF_SetAttrType(desc, "dtype", TF_FLOAT);
    TF_Operation* op = TF_FinishOperation(desc, s);
    CHECK_OK(s);
    return op;
}
/*
 * Append a "MatMul" node called `name` to `graph` that takes output 0 of `l`
 * and output 0 of `r` as its two inputs. A non-zero `transpose_a` /
 * `transpose_b` sets the matching attribute to true; otherwise that attribute
 * is left unset. Returns the finished operation, or NULL via CHECK_OK when `s`
 * reports an error.
 */
TF_Operation* MatMul(TF_Graph* graph, TF_Status* s, TF_Operation* l, TF_Operation* r, const char* name,
                     char transpose_a, char transpose_b) {
    TF_OperationDescription* builder = TF_NewOperation(graph, "MatMul", name);

    /* Only touch the transpose attributes when a transpose was requested. */
    if (transpose_a) {
        TF_SetAttrBool(builder, "transpose_a", 1);
    }
    if (transpose_b) {
        TF_SetAttrBool(builder, "transpose_b", 1);
    }

    /* Operands are wired in order: left-hand side first, then right-hand side. */
    const TF_Output operands[2] = {{l, 0}, {r, 0}};
    for (size_t k = 0; k < 2; ++k) {
        TF_AddInput(builder, operands[k]);
    }

    TF_Operation* product = TF_FinishOperation(builder, s);
    CHECK_OK(s);
    return product;
}
/*
 * Build z = MatMul(x, y) in the global graph `g`, where x and y are 2x2 float
 * constants, using the global status `s` for error reporting:
 *
 *             z
 *             ^
 *             |
 *          MatMul
 *           ^    ^
 *           |    |
 *      Const_0  Const_1
 *        (x)      (y)
 *
 * inputs  must have room for 2 entries; on success it receives output 0 of
 *         Const_0 and Const_1.
 * outputs must have room for 1 entry; on success it receives output 0 of
 *         MatMul.
 *
 * Returns g on success. On failure, CHECK_OK prints the status message and
 * NULL is returned; inputs/outputs are left untouched.
 *
 * The diagram uses a block comment with no trailing backslash. The old
 * "// / \" line spliced the following line into the comment.
 */
TF_Graph* BuildSuccessGraph(TF_Output* inputs, TF_Output* outputs) {
    float const0_val[] = {1.0f, 2.0f, 3.0f, 4.0f};
    float const1_val[] = {1.0f, 0.0f, 0.0f, 1.0f};

    /* Check after each node: a failed Const returns NULL, which must not be
     * fed into MatMul, and a later call could overwrite the error in `s`. */
    TF_Operation* const0 = FloatConst2x2(g, s, const0_val, "Const_0");
    CHECK_OK(s);
    TF_Operation* const1 = FloatConst2x2(g, s, const1_val, "Const_1");
    CHECK_OK(s);
    TF_Operation* matmul = MatMul(g, s, const0, const1, "MatMul", 0, 0);
    CHECK_OK(s);

    inputs[0] = (TF_Output){const0, 0};
    inputs[1] = (TF_Output){const1, 0};
    outputs[0] = (TF_Output){matmul, 0};
    return g;
}
int main(int argc, char const *argv[]) {
g = TF_NewGraph();
s = TF_NewStatus();
TF_Output inputs[2],outputs[1];
BuildSuccessGraph(inputs,outputs);
/* HERE traverse g -- maybe with {inputs,outputs} -- to print the graph */
fprintf(stdout, "OK\n");
}
如果有人能告诉我应该使用哪些函数来获取图的信息,将不胜感激。
答案 0(得分:2)
来自c_api.h:
// Iterate through the operations of a graph. To use:
//   size_t pos = 0;
//   TF_Operation* oper;
//   while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
//     DoSomethingWithOperation(oper);
//   }
// `pos` is a caller-owned cursor: initialize it to 0 and pass the same
// variable on every call. The loop ends when the function returns a null
// pointer (write `NULL` instead of the C++ `nullptr` when calling from C).
TF_CAPI_EXPORT extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph,
size_t* pos);
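注意:该函数只是逐个返回图中的操作(Operation,即节点),本身并不提供从一个节点导航到相邻节点的方法——节点之间的边关系存储在节点自身中(以指针形式)。要沿着边遍历,可以对每个操作调用 TF_OperationNumInputs / TF_OperationInput 查找其输入来自哪个节点,或调用 TF_OperationOutputNumConsumers / TF_OperationOutputConsumers 查找其输出被哪些节点使用。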