tf.cond降低了训练速度

时间:2017-03-30 02:53:55

标签: machine-learning tensorflow

我在tensorflow中使用了cifar10输入管道,例如tensorflow模型并尝试使用tf.cond进行验证,我写了类似这样的内容

train_data = model.input(istrain=True)
val_data = model.input(istrain=False)

# This selects which stream to use.
select_val = tf.placeholder(dtype=bool,shape=[],name='select_test')
data = tf.cond(
    select_val,
    lambda:val_data,
    lambda:train_data
)

# Here is the model.
loss = ...
train_op = ...
...

with tf.Session():
    ...

如果我删除cond并只使用训练数据,速度为4000个样本/秒,如果我使用上面的代码,速度会降低到2300个样本/秒。验证管道容量设置得非常小,因此GPU中不会占用太多内存。进行验证的频率也很低。 我不确定出了什么问题,请帮助我。

1 个答案:

答案 0 :(得分:5)

tf.cond并非完全懒惰。即使需要它的分支不是要执行的分支,也将运行cond的任一分支所需的任何操作。因此,在您的情况下,每次调用model.input(istrain=True) op时都会执行model.input(istrain=False)data。其中一个的结果被忽略了。

documentation for cond给出了一个最小的代码示例:

  

请注意,条件执行仅适用于操作   在fn1和fn2中定义。请考虑以下简单程序:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
     

如果x&lt; y,将执行tf.add操作和tf.square   操作不会被执行。因为至少需要一个z   cond的分支,tf.mul操作总是执行,   无条件的。虽然这种行为与此一致   TensorFlow的数据流模型,它偶尔会让一些用户感到惊讶   谁想要更懒惰的语义。

另请注意,这意味着如果您的model.input从较大的池中提取一些数据(例如,来自整个数据集的批处理),则每次运行cond时,数据都会获得从验证和训练中抽出来,一套就被扔掉了。在某些情况下,这可能导致比低效率更严重的问题。例如,如果您只处理特定数量的纪元,那么使用此代码,您实际上并未处理该数量的纪元,因为数据被拉出未使用的数据。