为什么“tf.Variable([。3],tf.float32)”在tensorflow中工作?

时间:2017-03-20 07:17:55

标签: python python-2.7 tensorflow

标准用法应为

tf.Variable([.3], dtype=tf.float32),不是吗?

我在官方文档中看到了tf.Variable([.3], tf.float32)tf.Variable的构造函数原型是

__init__(self, initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None, dtype=None, expected_shape=None, import_scope=None)

如果我们传递参数tf.float32而不是dtype=tf.float32(关键参数),它如何知道tf.float32用于dtype。 python解释器检查参数类型吗?

1 个答案:

答案 0 :(得分:3)

来自tf.Variable的文档:

  

dtype:如果设置,则initial_value将转换为给定类型。           如果None,将保留数据类型(如果initial_value为           a Tensor),或convert_to_tensor将决定。

来自convert_to_tensor(value, dtype=None, ...)的文档:

  

dtype:返回张量的可选元素类型。如果失踪了,         类型是从value

的类型推断出来的

此外,文档中给出了convert_to_tensor

的示例
import numpy as np

def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

# The following calls are equivalent.
value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))

所以,回到你的问题 - Tensorflow不知道你打算使用tf.float32,它恰好是convert_to_tensor函数默认选择的数据类型。因此,返回的Tensor具有您期望的数据类型仅仅是巧合。如果是您致电tf.Variable([.3], tf.float64)后生成的Tensor与调用dtype时的tf.Variable([.3], tf.float32)相同。

事实上,我认为调用tf.Variable([.3], tf.float32)tf.Variable([.3], tf.float64)都是等价的,因为tf.Variable的第二个参数是布尔值,因此tf.floatX被转换为总是返回True的布尔值。