如何在TensorFlow C ++ API中获得占位符大小?

时间:2018-12-26 11:10:17

标签: c++ tensorflow

我想使用C ++加载TensorFlow模型。我想知道模型输入的大小,即模型中的占位符。

我用谷歌搜索这个问题,但我只是在stackoverflow中找到此链接:

  

C++ equivalent of python: tf.Graph.get_tensor_by_name() in Tensorflow?

虽然我可以得到节点,但是tensorflow文档没有告诉我如何访问节点的大小。有人知道吗?

非常感谢您!

1 个答案:

答案 0 :(得分:0)

好,经过多次尝试。我找到了一种解决方法,它可能很棘手,但效果很好。

首先,我们可以使用以下代码获取占位符节点:

GraphDef mygd = graph_def.graph_def();
for (int i = 0; i < mygd.node_size(); i++)
{
    if (mygd.node(i).name() == input_name)
    {
        auto node = mygd.node(i);
    }
}

然后通过NodeDef.pd.h(tensorflow / core / framework / node_def.pb.h),我们可以通过以下代码获取AttrValue:

auto attr = node.attr();

然后通过attr_value.cc(tensorflow / core / framework / attr_value.cc),我们可以通过以下代码获取形状attr值:

tensorflow::AttrValue shape = attr["shape"];

和形状AttrValue是用于存储形状信息的结构。我们可以通过tensorflow / core / framework / attr_value_util.h

中的SummarizeAttrValue函数获取详细信息。
string size_summary = SummarizeAttrValue(shape);

然后我们可以获得形状的字符串格式,如下所示:

[?,1024]