tensorflow C ++等效于tf.trainable_variables()?

时间:2019-03-01 09:24:11

标签: c++ tensorflow

我的目标是从C ++ API获得具有所有可训练变量的名称列表。在Python中,使用tf.trainable_variables()可以解决此问题。

到目前为止,我尝试了这种方法。 我有一个tensorflow :: GraphDef对象,我可以看到所有这样创建的节点:

for (int i = 0; i < graphDef.node_size(); i++) {
    graphDef.node(i).PrintDebugString();
}

太好了。这些节点中有一些是指可训练的变量,但是我不知道如何获得该信息/或者是否可能。

1 个答案:

答案 0 :(得分:1)

该信息在GraphDef对象中不可用。 tf.trainable_variables只是返回带有键tf.GraphKeys.TRAINABLE_VARIABLES的图形集合,但是图形集合不会保存到GraphDef,而只会保存到MetaGraphDef(请参阅Exporting and Importing a MetaGraph)。如果要从C ++访问保存的图形中的可训练变量,则必须导出和导入MetaGraph,或者也许使用一致的命名方案来区分它们。

请注意,顺便说一下,在TensorFlow 2.x中将不赞成使用图形集合。有关更多信息,请参见Deprecating collections