如何重置tf.estimator.Estimator参数?

时间:2018-01-31 12:13:55

标签: tensorflow

我尝试了tf.Graph(),但无法通过new重置变量。代码如下:

with tf.Graph().as_default() as g:
    clf_ = tf.estimator.Estimator(model_fn=my_w2d.model_fn_wide2deep, params=param, model_dir="/Users/zhouliaoming/data/credit_dnn/model_retrain/rm_gene_v2_sall/")
    with tf.name_scope("rewrite"):
        clf2 = tf.estimator.Estimator(model_fn=my_w2d.model_fn_wide2deep, params=param, model_dir="/Users/zhouliaoming/data/credit_dnn/model_retrain/genev2_s0/")
    out_bias = tf.get_variable("output_0/bias")
    out_b_rew = tf.get_variable("rewrite/output_0/bias")
    vars_ = clf_.get_variable_names()   ## only has clf_.get_variable_values()
    print("vars: %r\n output_0/bias: %r\ntrain-vars: %r" % (vars_, clf_.get_variable_value('output_0/bias'), tf.contrib.framework.get_trainable_variables()))
    print("before rewrite: out_bias: %r, out_b_rew: %r" % (out_bias.eval(), out_b_rew.eval()))
    out_b_rew.assing(out_bias)
    print("after rewrite: out_bias: %r, out_b_rew: %r" % (out_bias.eval(), out_b_rew.eval()))

它只是返回错误:

Traceback (most recent call last):
  File "tf_utils.py", line 31, in <module>
    out_bias = tf.get_variable("output_0/bias")
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 1262, in get_variable
    constraint=constraint)
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 1097, in get_variable
    constraint=constraint)
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 435, in get_variable
    constraint=constraint)
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 404, in _true_getter
    use_resource=use_resource, constraint=constraint)
  File "/Users/zhouliaoming/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 764, in _get_single_variable
    "but instead was %s." % (name, shape))
ValueError: Shape of a new variable (output_0/bias) must be fully defined, but instead was <unknown>.

===============旧信息剪切线=========

我通过model_fn处理程序定义了一个tf.estimator.Estimator模型A. 我想用与ckpt文件相同的旧模型参数来更改模型A的参数。 我尝试获取模型A的图形,然后在Graph中获取参数的变量,然后通过我的旧模型参数进行分配。 希望一些建议! 非常感谢!

1 个答案:

答案 0 :(得分:1)

有很多方法可以做到这一点,具体取决于您可以使用的内容。例如,如果您拥有两个模型中的代码和检查点,则可以创建两个单独的图形(with tf.Graph() as g)将两个检查点加载到其中,从一个图形中读取变量值并将其分配给另一个图形中的变量

如果您确切地知道要在一个检查点中读取的变量,则可以只恢复它(Saver.restore获取要恢复的变量列表),或者您可以使用CheckpointReader等工具读取它