tf.while_loop的高效并行实现

时间:2018-07-12 20:36:44

标签: tensorflow while-loop parallel-processing

我想对并行执行操作的算法进行有效的实现,并将它们添加到张量流的共享和中。作为MWE,我比较了矩阵乘法的这两种实现,以检查其正确性:

import tensorflow as tf

m = 10000
n = 10000
p = 10000

A = tf.random_normal((m,n))
X = tf.random_normal((n,p))

# Option 1: multiply and sum
def tf_multiply(A, X):
    return tf.reduce_sum(tf.matmul(A, X))
y1 = tf_multiply(A, X)

# Option 2: multiply in tf using while_loop
c = lambda i, y: i < p
i0 = tf.constant(0)
y = tf.Variable(tf.constant(0.0))
def b(i, y):
    y += tf.reduce_sum(tf.matmul(A, tf.expand_dims(X[:,i], axis=-1)))
    return i+1, y
iter, y2 = tf.while_loop(cond=c, body=b, loop_vars=(i0, y), parallel_iterations=p)
with tf.Session() as sess:
    %timeit sess.run(y1)
    sess.run(y.initializer)
    %timeit sess.run(y2)

第一种方法需要602毫秒来执行操作,而第二种方法则需要27.3 s左右。实施不好吗?我可以使第二个实施更快吗?谢谢。

0 个答案:

没有答案