我正在使用tf.scatter_nd
来更新某个索引处的复杂值。
似乎真实和想象的部分以某种方式通过这个功能加在一起。我的问题是如何使其与占位符一起使用。以下是变量b
和e
应具有相同值的最小工作示例。
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
update=np.asarray([1.+2j])
idx=tf.constant( [[0]])
shp=tf.constant([1])
# works with constants
a=tf.constant(update)
b=tf.scatter_nd(idx,a,shp)
with tf.Session() as sess:
print sess.run(b) # correct output: 1.+2j
#Does not work with placeholders
d=tf.placeholder(tf.complex128)
e=tf.scatter_nd(idx,d,shp)
with tf.Session() as sess:
print sess.run(e,feed_dict={d:update}) # WRONG output: 3.+0j
我正在使用conda命令安装Anaconda python 2.7 + TensorFlow 1.7 GPU版本。
修改
在GPU上运行代码时会出现此问题。 CPU版本正常工作。 以下是使用Anaconda Python 2.7安装的TensorFlow-GPU 1.8中重现问题的更新代码。
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
update=np.asarray([1.+2j])
idx=tf.constant( [[0]])
shp=tf.constant([1])
a=tf.placeholder(tf.complex128)
with tf.device("/cpu:0"):
b=tf.scatter_nd(idx,a,shp)
with tf.device("/gpu:0"):
c=tf.scatter_nd(idx,a,shp)
with tf.Session() as sess:
print 'Correct output on CPU', sess.run(b,feed_dict={a:update})
print 'Wrong output on GPU',sess.run(c,feed_dict={a:update})
我看到了this线程和this线程,但找不到如何解决它。是否有可以在GPU上运行的tf.scatter_nd
的替代方案?