变量名称不适用于init_from_checkpoint吗?

时间:2018-12-31 07:33:56

标签: variables tensorflow

请查看此玩具模型:

import tensorflow as tf
import os

if not os.path.isdir('./temp'):
    os.mkdir('./temp')


def create_and_save_varialbe(sess=tf.Session()):
    a = tf.get_variable("a", [])
    saver_a = tf.train.Saver({"a": a})
    init = tf.global_variables_initializer()
    sess.run(init)
    saver_a.save(sess, './temp/temp_model')
    a = sess.run(a)
    print('the initialized a is %f' % a)
    return a


def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32) 
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': 'b'})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b


def init_get_variable(sess=tf.Session()):
    c = tf.get_variable("c", shape=[])
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': 'c'})
    init = tf.global_variables_initializer()
    sess.run(init)
    c = sess.run(c)
    print(c)
    return c


a = create_and_save_varialbe()
b = init_variable()
c = init_get_variable()

函数init_get_varialbe有效,但函数init_variable不起作用。

  

ValueError:仅具有作用域名称的分配映射应映射到作用域   只有一个。应该是“ scope /”:“ other_scope /”。

在这种情况下,为什么由变量定义的变量名称不起作用,我该如何解决?

Tensorflow版本:1.12

1 个答案:

答案 0 :(得分:0)

这是因为Variable和get_variable之间的difference

有两种解决方法:

1)输入名称以外的变量。

def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32) 
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': b})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b

因为它是可变的tensorflow可以得到它directly

if _is_variable(current_var_or_name) or (
    isinstance(current_var_or_name, list)
    and all(_is_variable(v) for v in current_var_or_name)):
  var = current_var_or_name

否则它将获得变量from variable store

  store_vars = vs._get_default_variable_store()._vars 

但是由Variable定义的变量不在answer中所述的('varstore_key',)集合中。

然后2)您可以自己将其添加到集合中:

from tensorflow.python.ops.variable_scope import _VariableStore
from tensorflow.python.framework import ops
def init_variable(sess=tf.Session()):
    b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32) 
    store = _VariableStore()
    store._vars = {'b': b}
    ops.add_to_collection(('__variable_store',), store)
    tf.train.init_from_checkpoint('./temp/temp_model', 
            {'a': 'b'})
    init = tf.global_variables_initializer()
    sess.run(init)
    b = sess.run(b)
    print(b)
    return b

两个工作。