(更新后的问题) 我认为最初的问题令人困惑,我找到了一种更简单的方式来提出这个问题。
#!/usr/bin/python
import tensorflow as tf
x = tf.Variable([2], tf.float32)
print x.dtype
如果我们尝试上面的代码段,那么输出如下:
<dtype: 'int32_ref'>
因为我明确地将x
的类型指定为tf.float32,所以我认为该类型应该是float32。但是,似乎类型是int32。
有人可以回答这个问题吗?
(原始问题)
我尝试使用以下代码替换2-D张量流数组的一个元素。
#!/usr/bin/python
import tensorflow as tf
import numpy as np
ref = tf.Variable(np.arange(0, 12).reshape((4, 3)).astype(np.float64))
indices = tf.constant([[2, 2]])
updates = tf.Variable([8.0], tf.float64)
ref = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
print sess.run(ref)
奇怪的是,我遇到了以下类型错误:
TypeError:输入&#39;更新&#39; &#39; ScatterNdUpdate&#39; Op的类型为float32,与参数&#39; ref&#39;的类型float64不匹配。
将tf.Variable([8.0], tf.float64)
更改为以下行后,就可以了。
updates = tf.Variable(np.array([8.0]).astype(np.float64), tf.float64)
所以,似乎tf.Variable([8.0], tf.float64)
的类型不是tf.float64,即使我明确地将类型指定为tf.float64。谁能告诉我原因?谢谢!
答案 0 :(得分:3)
原因很简单:您的代码会创建一个可训练的tf.Variable
(tf.float64
被解释为True
参数的trainable
。如果您只是添加dtype
,它会起作用:
updates = tf.Variable([8.0], dtype=tf.float64)
实际上,有类似的Q&A。
答案 1 :(得分:0)
更新
使用tf.Variable()
时使用dtype原始问题:
将 ref 类型设为float32对我有用。
做
ref = tf.Variable(np.arange(0, 12).reshape((4, 3)).astype(np.float32))
或
ref = tf.Variable(np.arange(0, 12).reshape((4, 3)),dtype=tf.float32)
我猜,ScatterNdUpdate操作仅适用于float32而不是float64。