我有一个从tfrecords
读取训练和验证数据集的管道
我使用tf.train.batch
构建批次。在培训期间,我想在验证数据集的培训和评估之间切换。
以下是我现在实现它的简化代码片段。
is_training_pl = tf.placeholder(tf.bool)
images_train, labels_train = tf.train.batch([img_train, label_train])
images_val, labels_val = tf.train.batch([img_val, label_val])
data = tf.cond(is_training_pl, lambda: [images_train, labels_train], lambda: [images_val, labels_val])
loss = my_model(input=data)
我知道可以使用tf.cond
来完成,但问题是当调用tf.cond
时,将执行train和val批处理操作。
在github上 ebrevdo 告诉(link to the comment),为此可以使用tf.train.maybe_batch
,这样更有效。
有人可以举例说明如何在我的案例中使用tf.train.batch
吗?