从TensorFlow的estimator.DNNClassifier获取权重

时间:2017-09-11 23:32:45

标签: python tensorflow

据我所知,DNNClassifier现已通过estimator.DNNClassifier进行培训。在使用contrib.learn.DNNClassifier训练之前,我们可以使用get_variable_names()提取权重。但是estimator.DNNClassifier中没有这样的方法。如果现在弃用contrib.learn,那么我们如何从新estimator.DNNClassifier获取权重?

1 个答案:

答案 0 :(得分:4)

显然,权重被称为“内核”。 (learnt from this question

例如,对于:

estimator = tf.estimator.DNNClassifier(
   feature_columns=feature_columns, 
   hidden_units=[2])

estimator.train(input_fn=input_fn_train)

您可以像这样使用get_variable_value

print(estimator.get_variable_value("dnn/hiddenlayer_0/kernel"))
print(estimator.get_variable_value("dnn/hiddenlayer_0/bias"))