为什么tf.get_variable('test')返回名称为test_1的变量?

时间:2018-09-27 11:02:15

标签: python tensorflow

我用tf.Variable创建了一个TensorFlow变量。为什么之后用相同的名称调用tf.get_variable时没有引发异常,而是创建了一个名称递增的新变量?

import tensorflow as tf

class QuestionTest(tf.test.TestCase):
    """Reproduce the question: tf.get_variable('test') yields a variable named 'test_1'."""

    def test_version(self):
        """Pin the TensorFlow version the behavior was observed on."""
        self.assertEqual(tf.__version__, '1.10.1')

    def test_variable(self):
        """A same-named tf.Variable does not stop tf.get_variable; a second get does."""
        a = tf.Variable(0., trainable=False, name='test')
        self.assertEqual(a.name, "test:0")

        # No exception here: get_variable creates a new variable whose name
        # has been made unique by appending "_1".
        b = tf.get_variable('test', shape=(), trainable=False)
        self.assertEqual(b.name, "test_1:0")

        # `a` and `b` are distinct objects; assert identity, not equality.
        self.assertIsNot(a, b, msg='`a` is not `b`')

        # Calling get_variable again with the same name is what raises.
        with self.assertRaises(ValueError) as ecm:
            tf.get_variable('test', shape=(), trainable=False)
        exception = ecm.exception
        self.assertStartsWith(str(exception), "Variable test already exists, disallowed.")

1 个答案:

答案 0 :(得分:2)

这是因为tf.Variable是一种低级方法,它将创建的变量存储在GLOBALS(或LOCALS)集合中;而tf.get_variable则把它创建的变量记录在变量存储(variable store)中,以此跟踪已创建的变量。

第一次调用tf.Variable时,创建的变量不会被添加到变量存储中,因此在变量存储看来,名称为"test"的变量尚未创建。

因此,当您以后调用tf.get_variable("test")时,它将查看变量存储,发现其中没有名称为"test"的变量。
因此它将调用tf.Variable,这将创建一个变量名,"test_1"以递增的名称"test"存储在变量存储区中,键为import tensorflow as tf class AnswerTest(tf.test.TestCase): def test_version(self): self.assertEqual(tf.__version__, '1.10.1') def test_variable_answer(self): """Using the default variable scope""" # Let first check the __variable_store and the GLOBALS collections. self.assertListEqual(tf.get_collection(("__variable_store",)), [], "No variable store.") self.assertListEqual(tf.global_variables(), [], "No global variables") a = tf.Variable(0., trainable=False, name='test') self.assertEqual(a.name, "test:0") self.assertListEqual(tf.get_collection(("__variable_store",)), [], "No variable store.") self.assertListEqual(tf.global_variables(), [a], "but `a` is in global variables.") b = tf.get_variable('test', shape=(), trainable=False) self.assertNotEqual(a, b, msg='`a` is not `b`') self.assertEqual(b.name, "test_1:0", msg="`b`'s name is not 'test'.") self.assertTrue(len(tf.get_collection(("__variable_store",))) > 0, "There is now a variable store.") var_store = tf.get_collection(("__variable_store",))[0] self.assertDictEqual(var_store._vars, {"test": b}, "and variable `b` is in it.") self.assertListEqual(tf.global_variables(), [a, b], "while `a` and `b` are in global variables.") with self.assertRaises(ValueError) as exception_context_manager: tf.get_variable('test', shape=(), trainable=False) exception = exception_context_manager.exception self.assertStartsWith(str(exception), "Variable test already exists, disallowed.")

    def test_variable_answer_with_variable_scope(self):
        """Repeat the experiment inside an explicit variable scope ("my_scope")."""
        self.assertListEqual(tf.get_collection(("__variable_store",)), [],
                             "No variable store.")

        # The scope object itself is never used, so it is not bound to a name.
        with tf.variable_scope("my_scope"):
            # Entering the scope alone is expected to create an empty store.
            self.assertGreater(len(tf.get_collection(("__variable_store",))), 0,
                               "There is now a variable store.")
            var_store = tf.get_collection(("__variable_store",))[0]
            # NOTE(review): `_vars` is private; inspected only for illustration.
            self.assertDictEqual(var_store._vars, {},
                                 "but with no variable in it.")

            # tf.Variable takes the scope prefix but still bypasses the store.
            a = tf.Variable(0., trainable=False, name='test')
            self.assertEqual(a.name, "my_scope/test:0")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(var_store._vars, {},
                                 "Still no variable in the store.")

            # get_variable gets a uniquified name, but is keyed by the
            # requested (scoped) name in the store.
            b = tf.get_variable('test', shape=(), trainable=False)
            self.assertEqual(b.name, "my_scope/test_1:0")
            var_store = tf.get_collection(("__variable_store",))[0]
            self.assertDictEqual(
                var_store._vars, {"my_scope/test": b},
                "`b` is in the store, but notice the difference between its "
                "name and its key in the store.")

            with self.assertRaises(ValueError) as exception_context_manager:
                tf.get_variable('test', shape=(), trainable=False)
            exception = exception_context_manager.exception
            self.assertStartsWith(
                str(exception),
                "Variable my_scope/test already exists, disallowed.")

使用显式变量作用域时也是如此。

(示例代码见上方的 test_variable_answer_with_variable_scope。)