TensorFlow非渴望模式下的float32测试类型

时间:2019-02-01 14:59:23

标签: tensorflow floating-point

我想在非渴望模式下测试张量的类型。在急切模式下,此代码有效:

tf.enable_eager_execution()
t = tf.random.normal([1, 1])
num_type = t.dtype
print(num_type == tf.float32) # prints `True`

在非渴望模式下,相同的代码不会,并且我发现测试的唯一方法是丑陋的str(num_type) == "float32"

sess = tf.Session()

t = sess.run(tf.random.normal([1, 1]))

num_type = t.dtype
print(num_type) # prints `float32`
print(str(num_type) == "float32") # prints `True`
print(num_type == float32) # returns `NameError: name 'float32' is not defined`
print(num_type == tf.float32) # returns `TypeError: data type not understood`

并且如果我尝试在会话中获取Tensor的类型:

t = tf.random.normal([1, 1])
t_type = t.dtype

num_type = sess.run(t_type)

然后我得到:

TypeError: Fetch argument tf.float32 has invalid type <class 'tensorflow.python.framework.dtypes.DType'>, must be a string or Tensor. (Can not convert a DType into a Tensor or Operation.)

如何在非迫切模式下测试float32类型?

1 个答案:

答案 0 :(得分:1)

在会话中评估张量后,您得到了一个麻木的张量对象,而不是tf.Tensor对象(直接在eager模式下使用)。

因此,您的测试应该是:

t.dtype == np.float32