在张量中将None传递给Tensorflow数据集shuffle函数

时间:2018-02-03 01:19:23

标签: tensorflow

我需要在每个训练时代结束时进行验证。为此,我计划使用tf.cond在从tf.data.TFRecordDataset读取的培训和验证数据之间进行选择。原始数据将通过mapshufflerepeatbatch函数获取next_element中的张量列表(请参阅下面的代码) )。

import tensorflow as tf

def read_from_tfrecords(

    pred, # tf.bool

    ## parameters for pred==True
    filenames_A, 
    batch_size_A = 20,
    num_epochs_A = None,
    buffer_size_A = 5000,
    seed_A = None,

    ## parameters for pred==False
    filenames_B,
    batch_size_B = 20, 
    num_epochs_B = None,
    buffer_size_B = 5000,
    seed_B = None

    ):

    filenames = tf.cond(
        pred,
        lambda: tf.constant(filenames_A, dtype=tf.string),
        lambda: tf.constant(filenames_B, dtype=tf.string)
        )

    batch_size = tf.cond(
        pred,
        lambda: tf.constant(batch_size_A, dtype=tf.int64),
        lambda: tf.constant(batch_size_B, dtype=tf.int64)
        )

    num_epochs = tf.cond(
        pred,
        lambda: tf.constant(num_epochs_A, dtype=tf.int64),
        lambda: tf.constant(num_epochs_B, dtype=tf.int64)
        )

    buffer_size = tf.cond(
        pred,
        lambda: tf.constant(buffer_size_A, dtype=tf.int64),
        lambda: tf.constant(buffer_size_B, dtype=tf.int64)
        )

    #-------------------------------------------------------#
    ## When either seed_A or seed_B is None,
    ## error "ValueError: None values not supported." is raised.
    seed = tf.cond(
        pred,
        lambda: tf.constant(seed_A, dtype=tf.int64),
        lambda: tf.constant(seed_B, dtype=tf.int64)
        )
    #-------------------------------------------------------#

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(_some_parse_function)
    dataset = dataset.shuffle(buffer_size=buffer_size, seed=seed)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)

    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    return iterator, next_element

predtf.bool)之外的所有参数都作为原始python类型传入。参数seed_Aseed_B(在shuffle函数中)同时接受Noneinteger。但是将None转换为tf.int64会引发错误ValueError: None values not supported.有没有办法将None转换为张量?

提前致谢。

1 个答案:

答案 0 :(得分:1)

数据集管道为您提供了比tf.cond更好的方法。

在程序员指南中搜索tf.data.Iterator.from_string_handle,在可输入迭代器的描述下有一个完整的例子。

https://www.tensorflow.org/programmers_guide/datasets

您可以定义2个数据集,一个用于您的列车,另一个用于测试。您为每个创建一个迭代器,然后创建一个伞形迭代器,它可以从一个或另一个读取,具体取决于您使用feed_dict传入的简单字符串句柄。

请注意此方法的一个重要好处:您通常希望对您不希望应用于测试数据的训练数据进行数据扩充。然后,您可以使用tf.cond方法运行数据扩充或不运行。但是从那条已经走过那条路线的人那里拿走它并为此烦恨自己,你会遇到很多陷阱,而且会有很糟糕的调试。

我现在以这种方式定义所有数据集。它使整个过程更容易理解,更容易调试。

请注意,列车数据集通常配置为ds.repeat(),而测试数据集未配置repeat。运行测试数据集时,您需要捕获OutOfRangeError,它表示数据的结束。然后,您可以在下次使用迭代时重新初始化测试数据集。