为什么tf.Variable([8.0],tf.float64)的类型在TensorFlow中浮动32而不是浮动64?

时间:2017-06-05 20:12:48

标签: types casting tensorflow

(更新后的问题) 我认为最初的问题令人困惑,我找到了一种更简单的方式来提出这个问题。

#!/usr/bin/python

import tensorflow as tf
x = tf.Variable([2], tf.float32)
print x.dtype

如果我们尝试上面的代码段,那么输出如下:

<dtype: 'int32_ref'>

因为我明确地将x的类型指定为tf.float32,所以我认为该类型应该是float32。但是,似乎类型是int32。

有人可以回答这个问题吗?

(原始问题)

我尝试使用以下代码替换2-D张量流数组的一个元素。

#!/usr/bin/python

import tensorflow as tf
import numpy as np

ref = tf.Variable(np.arange(0, 12).reshape((4, 3)).astype(np.float64))
indices = tf.constant([[2, 2]])
updates = tf.Variable([8.0], tf.float64)
ref = tf.scatter_nd_update(ref, indices, updates)

with tf.Session() as sess:
  init = tf.global_variables_initializer()
  sess.run(init)
  print sess.run(ref)

奇怪的是,我遇到了以下类型错误:

TypeError:输入&#39;更新&#39; &#39; ScatterNdUpdate&#39; Op的类型为float32,与参数&#39; ref&#39;的类型float64不匹配。

tf.Variable([8.0], tf.float64)更改为以下行后,就可以了。

updates = tf.Variable(np.array([8.0]).astype(np.float64), tf.float64)

所以,似乎tf.Variable([8.0], tf.float64)的类型不是tf.float64,即使我明确地将类型指定为tf.float64。谁能告诉我原因?谢谢!

2 个答案:

答案 0 :(得分:3)

原因很简单:您的代码会创建一个可训练的tf.Variabletf.float64被解释为True参数的trainable。如果您只是添加dtype,它会起作用:

    updates = tf.Variable([8.0], dtype=tf.float64)

实际上,有类似的Q&A

答案 1 :(得分:0)

更新

使用tf.Variable()

时使用dtype

原始问题:

ref 类型设为float32对我有用。

ref = tf.Variable(np.arange(0, 12).reshape((4, 3)).astype(np.float32))

ref = tf.Variable(np.arange(0, 12).reshape((4, 3)),dtype=tf.float32)

我猜,ScatterNdUpdate操作仅适用于float32而不是float64。