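Sliding batch window in TensorFlow using the Dataset API

Date: 2018-05-30 19:34:35

Tags: python tensorflow tensorflow-datasets

Is there a way to change how my images are grouped into batches? Currently, when I create a batch of size 4, for example, my batches look like this: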

Batch1: [Img0 Img1 Img2 Img3]
Batch2: [Img4 Img5 Img6 Img7]

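I need to change the composition of my batches so that each batch advances by only one image relative to the previous one. It should look like this: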

Batch1: [Img0 Img1 Img2 Img3]
Batch2: [Img1 Img2 Img3 Img4]
Batch3: [Img2 Img3 Img4 Img5]
Batch4: [Img3 Img4 Img5 Img6]
Batch5: [Img4 Img5 Img6 Img7]
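To make the requested ordering concrete, here is a minimal pure-Python sketch of it. It does not involve TensorFlow, and the helper name sliding_batches is hypothetical, introduced only to illustrate the desired output.

```python
def sliding_batches(items, batch_size):
    """Yield every run of `batch_size` consecutive items, advancing one item at a time."""
    for start in range(len(items) - batch_size + 1):
        yield items[start:start + batch_size]


images = ["Img%d" % i for i in range(8)]
for number, batch in enumerate(sliding_batches(images, 4), start=1):
    print("Batch%d:" % number, batch)
```

With 8 images and a batch size of 4 this produces the five overlapping batches listed above.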

I am using TensorFlow's Dataset API in my code, as follows:

import os
from functools import partial

import tensorflow as tf


def tfrecords_train_input(input_dir, examples, epochs, nsensors, past, future,
                          features, batch_size, threads, shuffle, record_type):
    filenames = sorted(
        [os.path.join(input_dir, f) for f in os.listdir(input_dir)])
    # Count the records across all files (for logging only)
    num_records = 0
    for fn in filenames:
        for _ in tf.python_io.tf_record_iterator(fn):
            num_records += 1
    print("Number of files to use:", len(filenames), "/ Total records to use:", num_records)
    dataset = tf.data.TFRecordDataset(filenames)
    # Parse records
    read_proto = partial(record_type().read_proto, nsensors=nsensors, past=past,
                         future=future, features=features)
    # Run the parsing function in parallel (`threads` parallel calls)
    dataset = dataset.map(map_func=read_proto, num_parallel_calls=threads)
    # Cache data
    dataset = dataset.cache()
    # Repeat for the requested number of epochs (no shuffling is applied here)
    dataset = dataset.repeat(epochs)
    # Batch data
    dataset = dataset.batch(batch_size)
    # Efficient pipelining
    dataset = dataset.prefetch(2)
    iterator = dataset.make_one_shot_iterator()
    return iterator

3 Answers:

Answer 0 (score: 8)

This can be achieved with the sliding window batch operation for tf.data.Dataset:

Example:

import tensorflow as tf
from tensorflow.contrib.data.python.ops import sliding

imgs = tf.constant(['img0','img1', 'img2','img3', 'img4','img5', 'img6', 'img7'])
labels = tf.constant([0, 0, 0, 1, 1, 1, 0, 0])

# create TensorFlow Dataset object
data = tf.data.Dataset.from_tensor_slices((imgs, labels))

# sliding window batch
window = 4
stride = 1
data = data.apply(sliding.sliding_window_batch(window, stride))

# create TensorFlow Iterator object
iterator = tf.data.Iterator.from_structure(data.output_types, data.output_shapes)
next_element = iterator.get_next()

# create initialization ops 
init_op = iterator.make_initializer(data)

with tf.Session() as sess:
    # initialize the iterator on the data
    sess.run(init_op)
    while True:
        try:
            elem = sess.run(next_element)
            print(elem)
        except tf.errors.OutOfRangeError:
            print("End of dataset.")
            break

Answer 1 (score: 2)

With tensorflow >= 2.1, the window(), flat_map() and batch() functions can be used to get the desired result.

Example:

import tensorflow as tf

## Sample data list
x_train = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 60, 70, 80, 90]

## Constants
batch_size = 10
shift_window_size = 1

## Create tensor slices
train_d = tf.data.Dataset.from_tensor_slices(x_train)

## Create dataset of datasets with a specific window and shift size
train_d = train_d.window(size=batch_size,shift=shift_window_size, drop_remainder=True)

## Define a function to create a flat dataset from the dataset of datasets
def create_sequence_ds(chunk):
    return chunk.batch(batch_size, drop_remainder=True)

## Flatten the dataset of datasets with flat_map and the function defined above
train_d = train_d.flat_map(create_sequence_ds)

## Check the contents
for item in train_d:
    print(item)

Output:

tf.Tensor([ 1  2  3  4  5  6  7  8  9 10], shape=(10,), dtype=int32)
tf.Tensor([ 2  3  4  5  6  7  8  9 10 20], shape=(10,), dtype=int32)
tf.Tensor([ 3  4  5  6  7  8  9 10 20 30], shape=(10,), dtype=int32)
tf.Tensor([ 4  5  6  7  8  9 10 20 30 40], shape=(10,), dtype=int32)
tf.Tensor([ 5  6  7  8  9 10 20 30 40 50], shape=(10,), dtype=int32)
tf.Tensor([ 6  7  8  9 10 20 30 40 50 60], shape=(10,), dtype=int32)
tf.Tensor([ 7  8  9 10 20 30 40 50 60 70], shape=(10,), dtype=int32)
tf.Tensor([ 8  9 10 20 30 40 50 60 70 80], shape=(10,), dtype=int32)
tf.Tensor([ 9 10 20 30 40 50 60 70 80 90], shape=(10,), dtype=int32)
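The nine rows above follow from the window arithmetic: with drop_remainder=True only windows that contain a full `size` elements are kept. A small pure-Python sketch of that count, assuming the default stride of 1 (the helper name num_full_windows is hypothetical, not part of TensorFlow):

```python
def num_full_windows(num_elements, size, shift):
    """Number of complete windows when incomplete trailing windows are dropped."""
    if num_elements < size:
        return 0
    # Window starts are 0, shift, 2*shift, ...; a window is complete
    # as long as start <= num_elements - size.
    return (num_elements - size) // shift + 1


x_train = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 60, 70, 80, 90]
print(num_full_windows(len(x_train), size=10, shift=1))  # prints 9
```

For the question's case of 8 images and a window of 4, the same formula gives 5 batches.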

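For more details, see: TF Data Guide

Answer 2 (score: 0)

Answering the original post, and responding to @cabbage_soup's comment on vijay's answer:

To get an efficient sliding window, the following code can be used.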

data = data.window(size=batch_size, stride=1, shift=1, drop_remainder=True)
data = data.interleave(
    lambda *window: tf.data.Dataset.zip(tuple([w.batch(batch_size) for w in window])),
    cycle_length=10, block_length=10, num_parallel_calls=4)
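The window() call above takes both shift and stride, which are easy to confuse: shift is how far the start of the window moves between consecutive windows, while stride is the spacing between elements inside one window. The following pure-Python sketch models that behaviour on a plain list; the function name window_model is hypothetical and the code is only an illustration, not TensorFlow's implementation.

```python
def window_model(items, size, shift=1, stride=1, drop_remainder=True):
    """Model of windowing over a list: `shift` moves the window start,
    `stride` is the step between elements inside a window."""
    windows = []
    for start in range(0, len(items), shift):
        chunk = items[start:start + size * stride:stride]
        if drop_remainder and len(chunk) < size:
            continue  # skip incomplete windows
        windows.append(chunk)
    return windows


# shift=1, stride=1: the overlapping batches asked for in the question
print(window_model(list(range(8)), size=4, shift=1, stride=1))
# shift=1, stride=2: every second element inside each window
print(window_model(list(range(7)), size=3, shift=1, stride=2))
```

With shift=1 and stride=1, as in the answer, consecutive windows overlap in all but one element.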

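interleave is used instead of flat_map because it allows processing to run in parallel during this window transformation.

Refer to the documentation to choose appropriate cycle_length, block_length and num_parallel_calls values for your hardware and data.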
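To get a feel for what cycle_length and block_length do, here is a simplified, sequential pure-Python model of the order in which interleave emits elements. It is a sketch under assumptions: the name interleave_model is hypothetical, and parallelism (num_parallel_calls) is not modelled at all.

```python
def interleave_model(inputs, map_func, cycle_length, block_length):
    """Cycle over `cycle_length` slots, taking up to `block_length`
    consecutive items from each slot before moving to the next one."""
    inputs = iter(inputs)
    slots = [None] * cycle_length
    inputs_done = False
    index, taken = 0, 0
    while True:
        # Fill an empty slot with the next input element, if any is left.
        if slots[index] is None and not inputs_done:
            try:
                slots[index] = iter(map_func(next(inputs)))
            except StopIteration:
                inputs_done = True
        if slots[index] is None:
            if all(slot is None for slot in slots):
                return  # nothing left anywhere
            index, taken = (index + 1) % cycle_length, 0
            continue
        try:
            item = next(slots[index])
        except StopIteration:
            # This slot is exhausted: free it and move on.
            slots[index] = None
            index, taken = (index + 1) % cycle_length, 0
            continue
        yield item
        taken += 1
        if taken == block_length:
            index, taken = (index + 1) % cycle_length, 0


# In the sliding-window case every window turns into exactly one batch,
# so the emitted order is simply the order of the windows.
windows = [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]]
print(list(interleave_model(windows, lambda w: [w], cycle_length=10, block_length=10)))

# With longer inner sequences the two parameters visibly change the order.
print(list(interleave_model(range(1, 4), lambda x: [x] * 3, cycle_length=2, block_length=2)))
```

In this model, because each window yields a single batch, the batches come out in the same order that flat_map would give, so any benefit of interleave here comes from parallel processing rather than from reordering.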