请查看此玩具模型:
import tensorflow as tf
import os
if not os.path.isdir('./temp'):
os.mkdir('./temp')
def create_and_save_varialbe(sess=tf.Session()):
a = tf.get_variable("a", [])
saver_a = tf.train.Saver({"a": a})
init = tf.global_variables_initializer()
sess.run(init)
saver_a.save(sess, './temp/temp_model')
a = sess.run(a)
print('the initialized a is %f' % a)
return a
def init_variable(sess=tf.Session()):
b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
tf.train.init_from_checkpoint('./temp/temp_model',
{'a': 'b'})
init = tf.global_variables_initializer()
sess.run(init)
b = sess.run(b)
print(b)
return b
def init_get_variable(sess=tf.Session()):
c = tf.get_variable("c", shape=[])
tf.train.init_from_checkpoint('./temp/temp_model',
{'a': 'c'})
init = tf.global_variables_initializer()
sess.run(init)
c = sess.run(c)
print(c)
return c
a = create_and_save_varialbe()
b = init_variable()
c = init_get_variable()
函数init_get_varialbe有效,但函数init_variable不起作用。
ValueError:仅具有作用域名称的分配映射应映射到作用域 只有一个。应该是“ scope /”:“ other_scope /”。
在这种情况下,为什么由变量定义的变量名称不起作用,我该如何解决?
Tensorflow版本:1.12
答案 0 :(得分:0)
这是因为Variable和get_variable之间的difference。
有两种解决方法:
1)输入名称以外的变量。
def init_variable(sess=tf.Session()):
b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
tf.train.init_from_checkpoint('./temp/temp_model',
{'a': b})
init = tf.global_variables_initializer()
sess.run(init)
b = sess.run(b)
print(b)
return b
因为它是可变的tensorflow可以得到它directly:
if _is_variable(current_var_or_name) or (
isinstance(current_var_or_name, list)
and all(_is_variable(v) for v in current_var_or_name)):
var = current_var_or_name
否则它将获得变量from variable store:
store_vars = vs._get_default_variable_store()._vars
但是由Variable定义的变量不在answer中所述的('varstore_key',)
集合中。
然后2)您可以自己将其添加到集合中:
from tensorflow.python.ops.variable_scope import _VariableStore
from tensorflow.python.framework import ops
def init_variable(sess=tf.Session()):
b = tf.Variable(tf.constant(1.0, shape=[]), name="b", dtype=tf.float32)
store = _VariableStore()
store._vars = {'b': b}
ops.add_to_collection(('__variable_store',), store)
tf.train.init_from_checkpoint('./temp/temp_model',
{'a': 'b'})
init = tf.global_variables_initializer()
sess.run(init)
b = sess.run(b)
print(b)
return b
两个工作。