我正在尝试制作可提供一批TFRecords的数据集,其中一个批次将有一个类别的2条随机记录,而其他类别的2条随机记录。
OR
一个批次数据集,其中每个类别中有2个随机记录适合该批次。
我尝试用tf.data.Dataset.from_generator
和tf.data.experimental.choose_from_datasets
来做到这一点,但没有成功。您对如何执行此操作有想法吗?
编辑: 今天,我认为我实施了第二种方案。这是我对其进行测试的代码。
def input_fn():
partial1 = tf.data.Dataset.from_tensor_slices(tf.range(0, 10)).repeat().shuffle(2)
partial2 = tf.data.Dataset.from_tensor_slices(tf.range(20, 30)).repeat().shuffle(2)
partial3 = tf.data.Dataset.from_tensor_slices(tf.range(60, 70)).repeat().shuffle(2)
l = [partial1, partial2, partial3]
def gen(x):
return tf.data.Dataset.range(x,x+1).repeat(2)
dataset = tf.data.Dataset.range(3).flat_map(gen).repeat(10)
choice = tf.data.experimental.choose_from_datasets(l, dataset).batch(4)
return choice
被撤回时返回的
[ 0 2 21 22]
[60 61 1 4]
[20 23 62 63]
[ 3 5 24 25]
[64 66 6 7]
[26 27 65 68]
[ 8 0 28 29]
[67 69 9 2]
[20 22 60 62]
[ 3 1 23 24]
[63 61 4 6]
[25 26 65 64]
[ 7 5 27 28]
[67 66 9 8]
[21 20 69 68]
答案 0 :(得分:2)
在TF 2.0中,现在可以使用dataset.interleave
读取差异类的tfrecords,并使用dataset.batch
来创建三元组对:
h = FcaeRecHelper('data/ms1m_img_ann.npy', [112, 112], 128, use_softmax=False)
len(h.train_list)
img_shape = list(h.in_hw) + [3]
is_augment = True
is_normlize = False
def parser(stream: bytes):
# parser tfrecords
examples: dict = tf.io.parse_single_example(
stream,
{'img': tf.io.FixedLenFeature([], tf.string),
'label': tf.io.FixedLenFeature([], tf.int64)})
return tf.image.decode_jpeg(examples['img'], 3), examples['label']
def pair_parser(raw_imgs, labels):
# imgs do same augment ~
if is_augment:
raw_imgs, _ = h.augment_img(raw_imgs, None)
# normlize image
if is_normlize:
imgs: tf.Tensor = h.normlize_img(raw_imgs)
else:
imgs = tf.cast(raw_imgs, tf.float32)
imgs.set_shape([4] + img_shape)
labels.set_shape([4, ])
# Note y_true shape will be [batch,3]
return (imgs[0], imgs[1], imgs[2]), (labels[:3])
batch_size = 1
# h.train_list : ['a.tfrecords','b.tfrecords','c.tfrecords',...]
ds = (tf.data.Dataset.from_tensor_slices(h.train_list)
.interleave(lambda x: tf.data.TFRecordDataset(x)
.shuffle(100)
.repeat(), cycle_length=-1,
# block_length = 2 is important
block_length=2,
num_parallel_calls=-1)
.map(parser, -1)
.batch(4, True)
.map(pair_parser, -1)
.batch(batch_size, True))
iters = iter(ds)
for i in range(20):
imgs, labels = next(iters)
fig, axs = plt.subplots(1, 3)
axs[0].imshow(imgs[0].numpy().astype('uint8')[0])
axs[1].imshow(imgs[1].numpy().astype('uint8')[0])
axs[2].imshow(imgs[2].numpy().astype('uint8')[0])
plt.show()
答案 1 :(得分:0)
好,我知道了。数据集已成功生成,并且数据随机性似乎不错。这不是三元组丢失的理想解决方案,因为三元组是随机的并且不是半硬的。
def input_fn(self, params):
batch_size = params['batch_size']
assert self.data_dir, 'data_dir is required'
shuffle = self.is_training
dirs = list(map(lambda x: os.path.join(x, 'train-*' if self.is_training else 'validation-*')), self.dirs)
def prefetch_dataset(filename):
dataset = tf.data.TFRecordDataset(
filename, buffer_size=FLAGS.prefetch_dataset_buffer_size)
return dataset
datasets = []
for glob in dirs:
dataset = tf.data.Dataset.list_files(glob)
dataset = dataset.apply(
tf.contrib.data.parallel_interleave(
prefetch_dataset,
cycle_length=FLAGS.num_files_infeed,
sloppy=True)) # if order is important
dataset = dataset.shuffle(batch_size, None, True).repeat().prefetch(batch_size)
datasets.append(dataset)
def gen(x):
return tf.data.Dataset.range(x,x+1).repeat(2)
choice = tf.data.Dataset.range(len(datasets)).repeat().flat_map(gen)
dataset = tf.data.experimental.choose_from_datasets(datasets, choice).map( # apply function to each element of the dataset in parallel
self.dataset_parser, num_parallel_calls=FLAGS.num_parallel_calls)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(8)
return dataset