tensorflow control_flow_ops不能正常工作

时间:2016-09-21 07:11:52

标签: tensorflow

我有一段tensorflow代码,使用control_flow_ops.cond来选择要使用的结果:

import tensorflow as tf
import numpy as np
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.client import timeline
import time

with tf.device('/cpu:0'):
    a_arr = []
    b = tf.Variable(tf.random_normal([1400, 5600]))
    c_arr = []
    d = tf.Variable(tf.zeros([1, 5600]))
    e_arr = []
    x = tf.placeholder(tf.int32, [250])
    y = tf.placeholder(tf.int32, [250])
    tf.scalar_summary('max/x', tf.reduce_max(x)) 
    for i in range(0, 250):
        a_arr.append(tf.Variable(tf.random_normal([1, 1400])))
        #c = tf.matmul(a_arr[i], b)
        **c = control_flow_ops.cond(x[i] < y[i], lambda: tf.matmul(a_arr[i], b), lambda:d)**
        e_arr.append(c)
    summary = tf.merge_all_summaries()
    e_arr.append(summary)

init = tf.initialize_all_variables()
with tf.Session() as sess:
    train_writer = tf.train.SummaryWriter('tensor_summary/train',
                                      sess.graph)
    sess.run(init)
    xi = [ 1 for i in range(0, 250) ]
    yi = [ 0 for i in range(0, 250) ]
    print(np.sum(xi < yi))

    for i in range(1000): 
        time_s = time.time() 
        out_arr = sess.run(e_arr, feed_dict={x:xi, y:yi})
        train_writer.add_summary(out_arr[-1], 1)
        time_e = time.time()
        print('duration = %f' %(time_e - time_s))

这里tf.MatMul不应该被执行,但它实际上是执行的,我在tensorflow 0.10.0上运行,在32核CPU上运行,它使用超过900个CPU,执行时间是13ms,节省时间线数据显示tf.MatMul也被执行。  这是一个测试tensorflow control_flow_ops.cond的测试用例,也用于bidirectional_rnn。  在这种情况下如何避免执行tf.MatMul,同时仍然使用control_flow_ops.cond动态选择两个结果中的一个?  有没有设置?

1 个答案:

答案 0 :(得分:0)

MatMuls并没有真正执行,即使时间轴包含它们。这有点令人困惑,我们会考虑删除它们。如果你有和没有cond的时间,你应该能够看到差异。