我有这个TensorFlow代码:
import tensorflow as tf
import numpy as np
with tf.device("gpu:0"):
sess = tf.InteractiveSession()
w = tf.Variable([1, 1, 1], dtype=tf.float32, name="weights")
examples = tf.Variable([[0, 1, 0], [0, 0, 1]], dtype=tf.float32, name="examples")
exampleId = tf.placeholder(shape=(), name="example_id", dtype=tf.int32)
updatew = tf.assign_sub(w, tf.gather_nd(examples, [exampleId]))
sess.run(tf.global_variables_initializer())
for exId in range(2):
sess.run(updatew, {exampleId:exId})
print(w.eval())
实际代码显然更复杂,但这个例子足以解决这个问题。
基本上:
G
D
的数据集D
G
中的每个示例
如果我循环GPU外部的示例(如上例所示),代码速度非常慢,比同一学习算法的当前C ++实现慢。
我想绕过D
中的示例,其中while
循环(或类似)位于图G
内部,因此位于GPU上。
我不能使用FIFO队列,因为它们只是CPU。
有没有人知道如何实现这一目标? TensorFlow有很多类可以使用'标准'来迭代迷你批次。函数,但我找不到迭代数据集并修改w
的方法。