我试图让TensorFlow的map_fn
在我的GPU上运行时遇到了一个奇怪的问题。这是一个极小的破坏的例子:
import numpy as np
import tensorflow as tf
with tf.Session() as sess:
with tf.device("/gpu:0"):
def test_func(i):
return i
test_range = tf.constant(np.arange(5))
test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)
这会导致错误:
InvalidArgumentError:无法分配设备进行操作 'map / TensorArray_1':无法满足显式设备规范 ''因为节点与一组需要的节点共存 不兼容的设备'/设备:GPU:0'主机托管调试信息:主机托管 group具有以下类型和设备:TensorArrayScatterV3:CPU TensorArrayGatherV3:GPU CPU范围:GPU CPU TensorArrayWriteV3:CPU TensorArraySizeV3:GPU CPU TensorArrayReadV3:CPU输入:GPU CPU TensorArrayV3:CPU Const:GPU CPU
主机托管成员和用户请求的设备:
map / TensorArrayStack / range / delta(Const)
map / TensorArrayStack / range / start(Const)map / TensorArray_1 (TensorArrayV3)map / while / TensorArrayWrite / TensorArrayWriteV3 / Enter (输入)/ device:GPU:0 map / TensorArrayStack / TensorArraySizeV3 (TensorArraySizeV3)map / TensorArrayStack / range(Range)
map / TensorArrayStack / TensorArrayGatherV3(TensorArrayGatherV3)
map / TensorArray(TensorArrayV3)map / while / TensorArrayReadV3 / Enter (输入)/设备:GPU:0 Const(Const)/ device:GPU:0
地图/ TensorArrayUnstack / TensorArrayScatter / TensorArrayScatterV3 (TensorArrayScatterV3)/ device:GPU:0 map / while / TensorArrayReadV3 (TensorArrayReadV3)/ device:GPU:0
map / while / TensorArrayWrite / TensorArrayWriteV3(TensorArrayWriteV3) /设备:GPU:0[[节点:map / TensorArray_1 = TensorArrayV3clear_after_read = true, dtype = DT_FLOAT,dynamic_size = false,element_shape =, identical_element_shapes = TRUE, tensor_array_name = “”]]
代码在我的CPU上运行时表现如预期,以及简单的操作,例如:
import numpy as np
import tensorflow as tf
with tf.Session() as sess:
with tf.device("/gpu:0"):
def test_func(i):
return i
test_range = tf.constant(np.arange(5))
test = sess.run(tf.add(test_range, test_range))
print(test)
在我的GPU上正常工作。 This post似乎描述了类似的问题。有人有任何提示吗?该帖子的答案暗示map_fn
应该可以在GPU上正常工作。我在Arch Linux上运行Python 3.6.4上的TensorFlow版本1.8.0,在GeForce GTX 1050上运行CUDA版本9.0和cuDNN版本7.0。
谢谢!
答案 0 :(得分:2)
错误实际上源于np.arange
默认生成int32
但您指定了float32
返回类型的事实。
import numpy as np
import tensorflow as tf
with tf.Session() as sess:
with tf.device("/gpu:0"):
def test_func(i):
return i
test_range = tf.constant(np.arange(5, dtype=np.float32))
test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)
我同意您收到的错误消息相当令人困惑。通过删除设备放置,您会收到“真实”错误消息:
import numpy as np
import tensorflow as tf
with tf.Session() as sess:
def test_func(i):
return i
test_range = tf.constant(np.arange(5))
test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)
# InvalidArgumentError (see above for traceback): TensorArray dtype is float but Op is trying to write dtype int32.