如何根据某些条件使用tf.train.batch()获取一批数据?

时间:2016-09-01 03:07:55

标签: tensorflow

我有一个功能' read_and_decode_Train'它从TFRecords数据集读取和解码单个图像和标签。然后我使用tf.train.batch()函数将BATCH_SIZE图像和标签序列化到images_batch和labels_batch。代码如下:

image, label = read_and_decode_Train(tfRecordsName)
images_batch, labels_batch = tf.train.batch([image, label], batch_size=BATCH_SIZE, num_threads=8, capacity=2000)

现在,我想根据某些条件将TFRecords数据集划分为训练,验证和测试数据集的三个子集,例如,如果我有一个csv文件,哪些行对应于TFRecords数据集的图像和标签,然后我根据csv文件划分数据集。我修改我的程序以添加条件,如下所示:

COUNT_TRAIN = -1    
def read_and_decode_Train(filename, csvLines, valNo, testNo):
    '''read and decode one single image and label
       from the TFRecords dataset.
    '''
    global COUNT_TRAIN
    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    while True:
        features = tf.parse_single_example(
            serialized_example,
            features={
                'label': tf.FixedLenFeature([], tf.int64),
                'img_raw': tf.FixedLenFeature([], tf.string)
            })
        image = tf.decode_raw(features['img_raw'], tf.float32)
        image = tf.reshape(image, [64, 64, 1])
        label = features['label']
        COUNT_TRAIN += 1
        if csvLines[COUNT_TRAIN][1] != valNo and csvLines[COUNT_TRAIN][1] != testNo:
            break
    return image, label

image, label = read_and_decode_Train(tfRecordsName, csvLines, valNo, testNo)
images_batch, labels_batch = tf.train.batch([image, label], batch_size=BATCH_SIZE, num_threads=8, capacity=2000)

然而,tf.train.batch()函数似乎像以前一样读取数据。 那么,在我的情况下,如何根据某些条件从TFRecords获取数据而不是读取数据呢? 感谢您的善意建议和建议。

1 个答案:

答案 0 :(得分:0)

您可能更容易将数据拆分为3个单独的TFRecords文件;一个用于培训,一个用于验证,一个用于测试。