我正在使用tf.Estimator
来创建我的模型。这是培训一段时间,然后estimator.export_savedmodel
。由于我使用辍学训练,我担心在训练后直接进行出口会在做预测时适用辍学。
现在我拥有的是一个加载了tf.saved_model.loader.load
的模型。我想我可以从我加载模型的会话中获取图形定义。我可以在这里检查辍学的价值吗?
答案 0 :(得分:1)
事实证明,您可以检查图表中任何变量或常量的值。毕竟,这是导出模型的目的。
您应该有权访问加载模型的会话。
在这种情况下,您可以浏览图表中的所有节点,如this question中所述,并提取与丢失值对应的节点。如果您没有为其指定特定名称,则默认为name_space/dropout/keep_prob
。
dropout_nodes = [node for node in sess.graph_def.node if 'dropout' in node.name]
然后,您可以检查任何此类节点的值。就我而言,它看起来像这样:
name: "deep_bidirectional_lstm/dropout/keep_prob"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 1.0
}
}
}
这是一个protobuf消息。它表示操作是"Const"
,其值为tensor
类型DT_FLOAT
,没有形状,值为1.0
您可以使用protobuf API将其解析为字典,或者如果只想要最后一部分,您可以像这样提取它:
print(dropout_nodes[0].attr.get('value').tensor.float_val[0])
1.0
所以你很安全,你的辍学是1 :)
大约1年后再回到这里,我意识到有一点混乱:当你说.attr.get('value')
时,'value'
指的是基于他们的{key
得到的两个属性中的哪一个。 {1}}:"dtype"
或"value"
。它与每个属性的value
属性无关。