如何在每次迭代中仅从一个类中采样批次

时间:2018-07-27 10:08:19

标签: python tensorflow

我想在一个ImageNet数据集上训练一个分类器(1000个类,每个类包含大约1300张图像)。由于某些原因,我需要每个批次包含相同类别的64个图像,以及不同类别的连续批次。最新的TensorFlow是否可能(高效)?

TF 1.9中的

tf.contrib.data.sample_from_datasets允许从tf.data.Dataset个对象列表中进行采样,其中weights表示概率。我想知道以下想法是否有意义:

  • 将每个类的数据另存为单独的tfrecord文件。
  • 传递一个tf.data.Dataset.from_generator对象作为weights。对象从分类分布中采样,因此每个采样看起来像[0,...,0,1,0,...,0],具有999个0和1个1
  • 创建1000个tf.data.Dataset对象,每个对象都链接了一个tfrecord文件。

我以这种方式认为,sample_from_datasets可能会在每次迭代时首先对一个稀疏权向量进行采样,该向量指示从哪个tf.data.Dataset进行采样,然后再从该类中采样。

对吗?还有其他有效的方法吗?

更新

如P-Gn所建议的那样,从一种类别中采样数据的一种方法是:

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(some_parser_fun)  # parse one datum from tfrecord
dataset = dataset.shuffle(buffer_size)

if sample_same_class:
    group_fun = tf.contrib.data.group_by_window(
        key_func=lambda data_x, data_y: data_y,
        reduce_func=lambda key, d: d.batch(batch_size),
        window_size=batch_size)
    dataset = dataset.apply(group_fun)
else:
    dataset = dataset.batch(batch_size)

dataset = dataset.repeat()
data_batch = dataset.make_one_shot_iterator().get_next()

可以在How to sample batch from a specific class?上找到后续问题

1 个答案:

答案 0 :(得分:3)

如果我正确理解,我认为您的解决方案将行不通,因为sample_from_dataset期望其weights而不是Tensor的值列表。

但是,如果您不介意提议的解决方案中有1000个Dataset,那么我建议简单地

  • 每个课程创建一个Dataset
  • batch每个数据集-每个批次都有来自单个类别的样本,
  • zip将它们全部分成一大批Dataset
  • shuffle这个Dataset —改组将在批次上发生,而不是在样本上发生,因此不会改变批次是单一类的事实。

一种更复杂的方法是依靠tf.contrib.data.group_by_window。让我用一个综合的例子来说明这一点。

import numpy as np
import tensorflow as tf

def gen():
  while True:
    x = np.random.normal()
    label = np.random.randint(10)
    yield x, label

batch_size = 4
batch = (tf.data.Dataset
  .from_generator(gen, (tf.float32, tf.int64), (tf.TensorShape([]), tf.TensorShape([])))
  .apply(tf.contrib.data.group_by_window(
    key_func=lambda x, label: label,
    reduce_func=lambda key, d: d.batch(batch_size),
    window_size=batch_size))
  .make_one_shot_iterator()
  .get_next())

sess = tf.InteractiveSession()
sess.run(batch)
# (array([ 0.04058843,  0.2843775 , -1.8626076 ,  1.1154234 ], dtype=float32),
# array([6, 6, 6, 6], dtype=int64))
sess.run(batch)
# (array([ 1.3600663,  0.5935658, -0.6740045,  1.174328 ], dtype=float32),
# array([3, 3, 3, 3], dtype=int64))