我想使用C ++加载TensorFlow模型。我想知道模型输入的大小,即模型中的占位符。
我用谷歌搜索这个问题,但我只是在stackoverflow中找到此链接:
C++ equivalent of python: tf.Graph.get_tensor_by_name() in Tensorflow?
虽然我可以得到节点,但是tensorflow文档没有告诉我如何访问节点的大小。有人知道吗?
非常感谢您!
答案 0 :(得分:0)
好,经过多次尝试。我找到了一种解决方法,它可能很棘手,但效果很好。
首先,我们可以使用以下代码获取占位符节点:
GraphDef mygd = graph_def.graph_def();
for (int i = 0; i < mygd.node_size(); i++)
{
if (mygd.node(i).name() == input_name)
{
auto node = mygd.node(i);
}
}
然后通过NodeDef.pd.h(tensorflow / core / framework / node_def.pb.h),我们可以通过以下代码获取AttrValue:
auto attr = node.attr();
然后通过attr_value.cc(tensorflow / core / framework / attr_value.cc),我们可以通过以下代码获取形状attr值:
tensorflow::AttrValue shape = attr["shape"];
和形状AttrValue是用于存储形状信息的结构。我们可以通过tensorflow / core / framework / attr_value_util.h
中的SummarizeAttrValue函数获取详细信息。string size_summary = SummarizeAttrValue(shape);
然后我们可以获得形状的字符串格式,如下所示:
[?,1024]