我想在非渴望模式下测试张量的类型。在急切模式下,此代码有效:
tf.enable_eager_execution()
t = tf.random.normal([1, 1])
num_type = t.dtype
print(num_type == tf.float32) # prints `True`
在非渴望模式下,相同的代码不会,并且我发现测试的唯一方法是丑陋的str(num_type) == "float32"
:
sess = tf.Session()
t = sess.run(tf.random.normal([1, 1]))
num_type = t.dtype
print(num_type) # prints `float32`
print(str(num_type) == "float32") # prints `True`
print(num_type == float32) # returns `NameError: name 'float32' is not defined`
print(num_type == tf.float32) # returns `TypeError: data type not understood`
并且如果我尝试在会话中获取Tensor的类型:
t = tf.random.normal([1, 1])
t_type = t.dtype
num_type = sess.run(t_type)
然后我得到:
TypeError: Fetch argument tf.float32 has invalid type <class 'tensorflow.python.framework.dtypes.DType'>, must be a string or Tensor. (Can not convert a DType into a Tensor or Operation.)
如何在非迫切模式下测试float32
类型?
答案 0 :(得分:1)
在会话中评估张量后,您得到了一个麻木的张量对象,而不是tf.Tensor对象(直接在eager模式下使用)。
因此,您的测试应该是:
t.dtype == np.float32