获取张量流图中常量节点的值

时间:2018-03-13 03:06:04

标签: tensorflow

我正在使用tf.Estimator来创建我的模型。这是培训一段时间,然后estimator.export_savedmodel。由于我使用辍学训练,我担心在训练后直接进行出口会在做预测时适用辍学。

现在我拥有的是一个加载了tf.saved_model.loader.load的模型。我想我可以从我加载模型的会话中获取图形定义。我可以在这里检查辍学的价值吗?

1 个答案:

答案 0 :(得分:1)

事实证明,您可以检查图表中任何变量或常量的值。毕竟,这是导出模型的目的。

您应该有权访问加载模型的会话。 在这种情况下,您可以浏览图表中的所有节点,如this question中所述,并提取与丢失值对应的节点。如果您没有为其指定特定名称,则默认为name_space/dropout/keep_prob

dropout_nodes = [node for node in sess.graph_def.node if 'dropout' in node.name]

然后,您可以检查任何此类节点的值。就我而言,它看起来像这样:

name: "deep_bidirectional_lstm/dropout/keep_prob"
op: "Const"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_FLOAT
      tensor_shape {
      }
      float_val: 1.0
    }
  }
}

这是一个protobuf消息。它表示操作是"Const",其值为tensor类型DT_FLOAT,没有形状,值为1.0

您可以使用protobuf API将其解析为字典,或者如果只想要最后一部分,您可以像这样提取它:

print(dropout_nodes[0].attr.get('value').tensor.float_val[0])
1.0

所以你很安全,你的辍学是1 :)

大约1年后再回到这里,我意识到有一点混乱:当你说.attr.get('value')时,'value'指的是基于他们的{key得到的两个属性中的哪一个。 {1}}:"dtype""value"。它与每个属性的value属性无关。