Using the tf.data.Dataset API to build each sample as the average of K (K = 5) samples from a random class

Time: 2019-01-17 15:43:03

Tags: python tensorflow

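I have a TFRecord dataset file (n_samples, feature_len, label) with n_samples = 1000000, feature_len = 512, and labels 0 to 6 (7 classes). I need to use the tf.data.Dataset API to feed the data into my network. Here is the problem: each sample in a batch must be the average of K (K = 5) random samples drawn from the corresponding class.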

#if BATCH_SIZE=4 and K=5 then have batch_labels = [6,0,3,2] and batch_data.shape = [4, 512]
#sample 0 of batch_data is average of 5 random samples of class batch_labels[0]
#sample 1 of batch_data is average of 5 random samples of class batch_labels[1]
#sample 2 of batch_data is average of 5 random samples of class batch_labels[2]
#sample 3 of batch_data is average of 5 random samples of class batch_labels[3]
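
To pin down the behaviour described above, here is a minimal NumPy sketch of the target semantics. It is only a reference for what one batch should contain, not a tf.data solution. The helper name `averaged_batch` and the toy in-memory arrays are assumptions made for illustration; the real data lives in the TFRecord file.

```python
import numpy as np

K = 5
BATCH_SIZE = 4
NUM_CLASSES = 7
FEATURE_LEN = 512

def averaged_batch(data, labels, rng):
    """Hypothetical helper: build one batch of class-wise K-sample averages."""
    # Pick a random class for every slot in the batch.
    batch_labels = rng.randint(0, NUM_CLASSES, size=BATCH_SIZE)
    rows = []
    for c in batch_labels:
        # Indices of all samples that belong to class c.
        candidates = np.flatnonzero(labels == c)
        # Draw K distinct samples of that class and average them.
        picked = rng.choice(candidates, size=K, replace=False)
        rows.append(data[picked].mean(axis=0))
    return np.stack(rows), batch_labels

# Toy stand-in data (assumption): 20 samples per class, each close to its label value.
rng = np.random.RandomState(0)
labels = np.repeat(np.arange(NUM_CLASSES), 20)
noise = rng.normal(scale=0.01, size=(len(labels), FEATURE_LEN))
data = (labels[:, None] + noise).astype(np.float32)

batch_data, batch_labels = averaged_batch(data, labels, rng)
print(batch_data.shape, batch_labels.shape)  # (4, 512) (4,)
```

Because every toy sample sits close to its label value, each averaged row should also sit close to the label of its class, which makes the result easy to check by eye.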

My code so far just gives me individual dataset samples!

import tensorflow as tf

BATCH_SIZE = 4

def parser(record):
    # Each record stores the feature vector and the label as raw bytes.
    features = {'data_raw': tf.FixedLenFeature([], tf.string),
                'label_raw': tf.FixedLenFeature([], tf.string)}
    parsed = tf.parse_single_example(record, features)
    # Decode the bytes into a float32 feature vector of length 512.
    data_raw = tf.decode_raw(parsed['data_raw'], tf.float32)
    data_resized = tf.reshape(data_raw, [512])
    # Decode the bytes into a single uint8 label.
    label_raw = tf.decode_raw(parsed['label_raw'], tf.uint8)
    label_resized = tf.reshape(label_raw, [1])
    return data_resized, label_resized

dataset = tf.data.TFRecordDataset('dataset.tfrecords')
dataset = dataset.map(parser)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.repeat()
dataset = dataset.batch(BATCH_SIZE)

# TF 1.x graph mode: fetch one batch through a one-shot iterator.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
batch_data, batch_label = sess.run(next_element)
print(batch_data.shape, batch_label.shape) # (4, 512) , (4, 1)
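
One possible direction, sketched under several assumptions, is to build one stream per class, batch K elements of each stream and average them, then interleave and shuffle the streams. Unlike the session-based code above, this sketch assumes TensorFlow 2.x with eager execution. It uses toy in-memory data in place of the TFRecord file, scalar labels instead of shape `[1]`, and a hypothetical helper `class_stream`. Note that it filters the base dataset once per class, which may be slow on a file with 1000000 records.

```python
import numpy as np
import tensorflow as tf

K = 5
BATCH_SIZE = 4
NUM_CLASSES = 7
FEATURE_LEN = 512

# Toy stand-in data (assumption): 20 samples per class, each close to its label value.
rng = np.random.RandomState(0)
labels = np.repeat(np.arange(NUM_CLASSES), 20).astype(np.int64)
noise = rng.normal(scale=0.01, size=(len(labels), FEATURE_LEN))
data = (labels[:, None] + noise).astype(np.float32)

base = tf.data.Dataset.from_tensor_slices((data, labels))

def class_stream(c):
    """Hypothetical helper: endless stream of K-sample averages for class c."""
    ds = base.filter(lambda x, y: tf.equal(y, c))
    # Shuffle within the class, then group K consecutive samples.
    ds = ds.shuffle(100).repeat().batch(K)
    # Average the K feature vectors; all K labels are equal, so keep the first.
    return ds.map(lambda x, y: (tf.reduce_mean(x, axis=0), y[0]))

# Round-robin over the 7 class streams, then shuffle so class order is random.
dataset = tf.data.Dataset.range(NUM_CLASSES).interleave(
    class_stream, cycle_length=NUM_CLASSES, block_length=1)
dataset = dataset.shuffle(64).batch(BATCH_SIZE)

batch_data, batch_labels = [t.numpy() for t in next(iter(dataset))]
print(batch_data.shape, batch_labels.shape)  # (4, 512) (4,)
```

The per-class shuffle buffer (100) and the final shuffle buffer (64) are arbitrary values chosen for the toy data; on the real dataset they would need to be tuned against memory and the desired randomness.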

0 answers:

No answers yet