我需要在每个训练时代结束时进行验证。为此,我计划使用tf.cond
在从tf.data.TFRecordDataset
读取的培训和验证数据之间进行选择。原始数据将通过map
,shuffle
,repeat
和batch
函数获取next_element
中的张量列表(请参阅下面的代码) )。
import tensorflow as tf
def read_from_tfrecords(
pred, # tf.bool
## parameters for pred==True
filenames_A,
batch_size_A = 20,
num_epochs_A = None,
buffer_size_A = 5000,
seed_A = None,
## parameters for pred==False
filenames_B,
batch_size_B = 20,
num_epochs_B = None,
buffer_size_B = 5000,
seed_B = None
):
filenames = tf.cond(
pred,
lambda: tf.constant(filenames_A, dtype=tf.string),
lambda: tf.constant(filenames_B, dtype=tf.string)
)
batch_size = tf.cond(
pred,
lambda: tf.constant(batch_size_A, dtype=tf.int64),
lambda: tf.constant(batch_size_B, dtype=tf.int64)
)
num_epochs = tf.cond(
pred,
lambda: tf.constant(num_epochs_A, dtype=tf.int64),
lambda: tf.constant(num_epochs_B, dtype=tf.int64)
)
buffer_size = tf.cond(
pred,
lambda: tf.constant(buffer_size_A, dtype=tf.int64),
lambda: tf.constant(buffer_size_B, dtype=tf.int64)
)
#-------------------------------------------------------#
## When either seed_A or seed_B is None,
## error "ValueError: None values not supported." is raised.
seed = tf.cond(
pred,
lambda: tf.constant(seed_A, dtype=tf.int64),
lambda: tf.constant(seed_B, dtype=tf.int64)
)
#-------------------------------------------------------#
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_some_parse_function)
dataset = dataset.shuffle(buffer_size=buffer_size, seed=seed)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
return iterator, next_element
除pred
(tf.bool
)之外的所有参数都作为原始python类型传入。参数seed_A
和seed_B
(在shuffle
函数中)同时接受None
和integer
。但是将None
转换为tf.int64
会引发错误ValueError: None values not supported.
有没有办法将None
转换为张量?
提前致谢。
答案 0 :(得分:1)
数据集管道为您提供了比tf.cond
更好的方法。
在程序员指南中搜索tf.data.Iterator.from_string_handle
,在可输入迭代器的描述下有一个完整的例子。
https://www.tensorflow.org/programmers_guide/datasets
您可以定义2个数据集,一个用于您的列车,另一个用于测试。您为每个创建一个迭代器,然后创建一个伞形迭代器,它可以从一个或另一个读取,具体取决于您使用feed_dict
传入的简单字符串句柄。
请注意此方法的一个重要好处:您通常希望对您不希望应用于测试数据的训练数据进行数据扩充。然后,您可以使用tf.cond
方法运行数据扩充或不运行。但是从那条已经走过那条路线的人那里拿走它并为此烦恨自己,你会遇到很多陷阱,而且会有很糟糕的调试。
我现在以这种方式定义所有数据集。它使整个过程更容易理解,更容易调试。
请注意,列车数据集通常配置为ds.repeat()
,而测试数据集未配置repeat
。运行测试数据集时,您需要捕获OutOfRangeError
,它表示数据的结束。然后,您可以在下次使用迭代时重新初始化测试数据集。