在TensorFlow中更改默认的float DType

时间:2019-05-17 08:01:34

标签: tensorflow default dtype

由于我们主要使用TensorFlow进行似然拟合而不是深度学习,因此我们始终将float64用于任何操作。这意味着我们必须为大多数操作明确指定dtype。首先转换为适当的类型。实际上,当TensorFlow缺乏经验的用户实现了一个简单的功能时,例如,我们经常会遇到dtype不兼容的情况。在这里:

sqrt1 = tf.sqrt(5)  # dtype error: cannot be integer

,如果它与5(而不是5)一起使用,则sqrt1将为dtype float32,从而导致

x = tf.constant(1., dtype=tf.float64)  # defined outside
def foo(x):  # implemented by the user
    sqrt1 = tf.sqrt(5) * x
foo(x)  # raises dtype error, float32 * float64

对于高级用户而言,这很烦人,但是可以小心一点,对于刚接触TensorFlow并打算编写一个小功能的人来说,这是一个真正的选择。

是否有可能以合理的方式将默认dtype更改为tf.float64?可以假设代码中不会使用tf.float32(至少从我们这边来看,TF可能会损坏)。

如果没有,什么是不合理的选择?  -覆盖convert_to_tensor(优先级较高的寄存器)将任何float32自动转换为float64吗?  -包装TF可能需要太多工作  -将即时导入时将tf.float32更改为tf.float64?疯狂但可行吗?

我知道您可以使用Variables进行操作,问题显然是关于Tensors。

0 个答案:

没有答案