Tensorflow:tf.get_variable如何工作?

时间:2017-07-13 07:26:00

标签: tensorflow

我已经从这个question读到了tf.get_variable,还有一些来自tensorflow网站上提供的文档。但是,我仍然不清楚,无法在网上找到答案。

tf.get_variable如何运作?例如:

var1 = tf.Variable(3.,dtype=float64)
var2 = tf.get_variable("var1",[],dtype=tf.float64)

是否 var2 另一个变量,其初始化类似于 var1 ?或者 var2 var1 的别名(我试过了,它似乎没有)?

var1 var2 如何相关?

当我们获取的变量确实存在时,如何构造变量?

2 个答案:

答案 0 :(得分:17)

tf.get_variable(name)在张量流图中创建一个名为name的新变量(或在当前作用域中已添加_ name)。

在您的示例中,您创建了一个名为var1 python 变量。

** Tensorflow图表中该变量的名称不是** var1,而是Variable:0

您定义的每个节点都有自己可以指定的名称,或者让tensorflow给出一个默认(并且始终不同)的名称。您可以看到name值访问python变量的name属性。 (即print(var1.name))。

在第二行,您要定义 Python变量 var2,其张量流图中的名称var1

脚本

import tensorflow as tf

var1 = tf.Variable(3.,dtype=tf.float64)
print(var1.name)
var2 = tf.get_variable("var1",[],dtype=tf.float64)
print(var2.name)

实际上是印刷品:

Variable:0
var1:0

如果您想要在tensorflow图中定义一个名为var1的变量(节点),然后获取对该节点的引用,那么不能只使用{{1}因为它会创建一个新的不同变量valled tf.get_variable("var1")

此脚本

var1_1

打印:

var1 = tf.Variable(3.,dtype=tf.float64, name="var1")
print(var1.name)
var2 = tf.get_variable("var1",[],dtype=tf.float64)
print(var2.name)

如果您想创建对节点var1:0 var1_1:0 的引用,请先:

  1. 必须将var1替换为tf.Variable。使用tf.get_variable创建的变量无法共享,而后者则可以。

  2. 了解tf.Variable的{​​{1}}是什么,并在声明参考时允许该范围的scope

  3. 查看代码是了解

    的更好方法
    var1

    输出:

    reuse

答案 1 :(得分:1)

如果您使用之前定义的名称定义变量,则TensorFlow会引发异常。因此,使用tf.get_variable()函数代替tf.Variable()很方便。函数tf.get_variable()返回具有相同名称的现有变量(如果存在),并创建具有指定形状和初始值设定项的变量(如果不存在)。