我有一个功能' read_and_decode_Train'它从TFRecords数据集读取和解码单个图像和标签。然后我使用tf.train.batch()函数将BATCH_SIZE图像和标签序列化到images_batch和labels_batch。代码如下:
image, label = read_and_decode_Train(tfRecordsName)
images_batch, labels_batch = tf.train.batch([image, label], batch_size=BATCH_SIZE, num_threads=8, capacity=2000)
现在,我想根据某些条件将TFRecords数据集划分为训练,验证和测试数据集的三个子集,例如,如果我有一个csv文件,哪些行对应于TFRecords数据集的图像和标签,然后我根据csv文件划分数据集。我修改我的程序以添加条件,如下所示:
COUNT_TRAIN = -1
def read_and_decode_Train(filename, csvLines, valNo, testNo):
'''read and decode one single image and label
from the TFRecords dataset.
'''
global COUNT_TRAIN
filename_queue = tf.train.string_input_producer([filename], num_epochs=None)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
while True:
features = tf.parse_single_example(
serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['img_raw'], tf.float32)
image = tf.reshape(image, [64, 64, 1])
label = features['label']
COUNT_TRAIN += 1
if csvLines[COUNT_TRAIN][1] != valNo and csvLines[COUNT_TRAIN][1] != testNo:
break
return image, label
image, label = read_and_decode_Train(tfRecordsName, csvLines, valNo, testNo)
images_batch, labels_batch = tf.train.batch([image, label], batch_size=BATCH_SIZE, num_threads=8, capacity=2000)
然而,tf.train.batch()函数似乎像以前一样读取数据。 那么,在我的情况下,如何根据某些条件从TFRecords获取数据而不是读取数据呢? 感谢您的善意建议和建议。
答案 0 :(得分:0)
您可能更容易将数据拆分为3个单独的TFRecords
文件;一个用于培训,一个用于验证,一个用于测试。