我在tensorflow
中使用了cifar10
输入管道,例如tensorflow
模型并尝试使用tf.cond
进行验证,我写了类似这样的内容
train_data = model.input(istrain=True)
val_data = model.input(istrain=False)
# This selects which stream to use.
select_val = tf.placeholder(dtype=bool,shape=[],name='select_test')
data = tf.cond(
select_val,
lambda:val_data,
lambda:train_data
)
# Here is the model.
loss = ...
train_op = ...
...
with tf.Session():
...
如果我删除cond并只使用训练数据,速度为4000个样本/秒,如果我使用上面的代码,速度会降低到2300个样本/秒。验证管道容量设置得非常小,因此GPU中不会占用太多内存。进行验证的频率也很低。 我不确定出了什么问题,请帮助我。
答案 0 :(得分:5)
tf.cond
并非完全懒惰。即使需要它的分支不是要执行的分支,也将运行cond
的任一分支所需的任何操作。因此,在您的情况下,每次调用model.input(istrain=True)
op时都会执行model.input(istrain=False)
和data
。其中一个的结果被忽略了。
documentation for cond
给出了一个最小的代码示例:
请注意,条件执行仅适用于操作 在fn1和fn2中定义。请考虑以下简单程序:
z = tf.multiply(a, b) result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
如果x&lt; y,将执行tf.add操作和tf.square 操作不会被执行。因为至少需要一个z cond的分支,tf.mul操作总是执行, 无条件的。虽然这种行为与此一致 TensorFlow的数据流模型,它偶尔会让一些用户感到惊讶 谁想要更懒惰的语义。
另请注意,这意味着如果您的model.input
从较大的池中提取一些数据(例如,来自整个数据集的批处理),则每次运行cond
时,数据都会获得从验证和训练中抽出来,一套就被扔掉了。在某些情况下,这可能导致比低效率更严重的问题。例如,如果您只处理特定数量的纪元,那么使用此代码,您实际上并未处理该数量的纪元,因为数据被拉出未使用的数据。