我已经在这个问题here和here上阅读了一些答案,但我仍然对tf.Variable
存在和/或不是tf.Tensor
感到有些困惑。
链接的答案处理tf.Variable
的可变性,并提到tf.Variable
维护其状态(使用默认参数trainable=True
实例化时)。
让我感到有点困惑的是我在使用tf.test.TestCase编写简单单元测试时遇到的测试用例
请考虑以下代码段。我们有一个名为Foo
的简单类,它只有一个属性,tf.Variable
初始化为w
:
import tensorflow as tf
import numpy as np
class Foo:
def __init__(self, w):
self.w = tf.Variable(w)
现在,假设您要测试Foo
的实例是否已经初始化,其张量与通过w
传递的张量相同。最简单的测试用例可以写成如下:
import tensorflow as tf
import numpy as np
from foo import Foo
class TestFoo(tf.test.TestCase):
def test_init(self):
w = np.random.rand(3,2)
foo = Foo(w)
init = tf.global_variables_initializer()
with self.test_session() as sess:
sess.run(init)
self.assertShapeEqual(w, foo.w)
if __name__ == '__main__':
tf.test.main()
现在,当您运行测试时,您将收到以下错误:
======================================================================
ERROR: test_init (__main__.TestFoo)
----------------------------------------------------------------------
Traceback (most recent call last):
File "test_foo.py", line 12, in test_init
self.assertShapeEqual(w, foo.w)
File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/test_util.py", line 1100, in assertShapeEqual
raise TypeError("tf_tensor must be a Tensor")
TypeError: tf_tensor must be a Tensor
----------------------------------------------------------------------
Ran 2 tests in 0.027s
FAILED (errors=1)
您可以通过执行以下操作来“绕过”此单元测试错误(即注释assertShapeEqual
已替换为assertEqual
):
self.assertEqual(list(w.shape), foo.w.get_shape().as_list())
我感兴趣的是tf.Variable
vs tf.Tensor
关系。
测试错误似乎表明foo.w
不一个tf.Tensor
,这意味着您可能无法在其上使用tf.Tensor
API。但是,请考虑以下交互式python会话:
$ python3
Python 3.6.3 (default, Oct 4 2017, 06:09:15)
[GCC 4.2.1 Compatible Apple LLVM 9.0.0 (clang-900.0.37)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> import numpy as np
>>> w = np.random.rand(3,2)
>>> var = tf.Variable(w)
>>> var.get_shape().as_list()
[3, 2]
>>> list(w.shape)
[3, 2]
>>>
在上面的会话中,我们创建一个变量并在其上运行get_shape()
方法以检索其形状尺寸。现在,get_shape()
方法是tf.Tensor
API方法,您可以看到here。
回到我的问题,tf.Tensor
API的哪些部分tf.Variable
会实现。如果答案是全部,为什么上述测试用例失败?
self.assertShapeEqual(w, foo.w)
与
raise TypeError("tf_tensor must be a Tensor")
我很确定我错过了一些基本的东西,或者这可能是assertShapeEqual中的错误?如果有人能对此有所了解,我将不胜感激。
我在使用tensorflow
的macOS上使用以下版本的python3
:
tensorflow (1.4.1)
答案 0 :(得分:0)
测试实用程序函数正在检查变量是否实现tf.Tensor
>>> import tensorflow as tf
>>> v = tf.Variable('v')
>>> v
<tf.Variable 'Variable:0' shape=() dtype=string_ref>
>>> isinstance(v, tf.Tensor)
False
答案似乎是'不'。
更新
根据正确的文件:
https://www.tensorflow.org/programmers_guide/variables
与tf.Tensor对象不同,tf.Variable存在于上下文之外 一个session.run调用。
虽然:
tf.Variable 表示张量,其值可以改变 在它上面运行操作。
(不太确定'代表张量'意味着什么 - 听起来像设计'特征')