tf.Variable是否是一个张量?

时间:2018-01-09 16:55:04

标签: unit-testing tensorflow

我已经在这个问题herehere上阅读了一些答案,但我仍然对tf.Variable存在和/或不是tf.Tensor感到有些困惑。 链接的答案处理tf.Variable的可变性,并提到tf.Variable维护其状态(使用默认参数trainable=True实例化时)。

让我感到有点困惑的是我在使用tf.test.TestCase编写简单单元测试时遇到的测试用例

请考虑以下代码段。我们有一个名为Foo的简单类,它只有一个属性,tf.Variable初始化为w

import tensorflow as tf
import numpy as np

class Foo:
    def __init__(self, w):
        self.w = tf.Variable(w)

现在,假设您要测试Foo的实例是否已经初始化,其张量与通过w传递的张量相同。最简单的测试用例可以写成如下:

import tensorflow as tf
import numpy as np
from foo import Foo

class TestFoo(tf.test.TestCase):
    def test_init(self):
        w = np.random.rand(3,2)
        foo = Foo(w)
        init = tf.global_variables_initializer()
        with self.test_session() as sess:
            sess.run(init)
            self.assertShapeEqual(w, foo.w)

if __name__ == '__main__':
      tf.test.main()

现在,当您运行测试时,您将收到以下错误:

======================================================================
ERROR: test_init (__main__.TestFoo)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "test_foo.py", line 12, in test_init
    self.assertShapeEqual(w, foo.w)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/test_util.py", line 1100, in assertShapeEqual
    raise TypeError("tf_tensor must be a Tensor")
TypeError: tf_tensor must be a Tensor

----------------------------------------------------------------------
Ran 2 tests in 0.027s

FAILED (errors=1)

您可以通过执行以下操作来“绕过”此单元测试错误(即注释assertShapeEqual已替换为assertEqual):

self.assertEqual(list(w.shape), foo.w.get_shape().as_list())

我感兴趣的是tf.Variable vs tf.Tensor关系。 测试错误似乎表明foo.w 一个tf.Tensor,这意味着您可能无法在其上使用tf.Tensor API。但是,请考虑以下交互式python会话:

$ python3
Python 3.6.3 (default, Oct  4 2017, 06:09:15)
[GCC 4.2.1 Compatible Apple LLVM 9.0.0 (clang-900.0.37)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> import numpy as np
>>> w = np.random.rand(3,2)
>>> var = tf.Variable(w)
>>> var.get_shape().as_list()
[3, 2]
>>> list(w.shape)
[3, 2]
>>>

在上面的会话中,我们创建一个变量并在其上运行get_shape()方法以检索其形状尺寸。现在,get_shape()方法是tf.Tensor API方法,您可以看到here

回到我的问题,tf.Tensor API的哪些部分tf.Variable会实现。如果答案是全部,为什么上述测试用例失败?

self.assertShapeEqual(w, foo.w)

raise TypeError("tf_tensor must be a Tensor")

我很确定我错过了一些基本的东西,或者这可能是assertShapeEqual中的错误?如果有人能对此有所了解,我将不胜感激。

我在使用tensorflow的macOS上使用以下版本的python3

tensorflow (1.4.1)

1 个答案:

答案 0 :(得分:0)

测试实用程序函数正在检查变量是否实现tf.Tensor

>>> import tensorflow as tf
>>> v = tf.Variable('v')
>>> v
<tf.Variable 'Variable:0' shape=() dtype=string_ref>
>>> isinstance(v, tf.Tensor)
False

答案似乎是'不'。

更新

根据正确的文件:

https://www.tensorflow.org/programmers_guide/variables

  

与tf.Tensor对象不同,tf.Variable存在于上下文之外   一个session.run调用。

虽然:

  

tf.Variable 表示张量,其值可以改变   在它上面运行操作。

(不太确定'代表张量'意味着什么 - 听起来像设计'特征')