并行化tf.data.Dataset.from_generator

时间:2017-11-03 00:16:33

标签: tensorflow tensorflow-datasets

我有一个非常简单的输入管道from_generator非常适合......

dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                        (tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()

complex_img_label_generator动态生成图像并返回表示(H, W, 3)图像和简单string标签的numpy数组。处理不是我可以表示从文件和tf.image操作中读取的内容。

我的问题是关于如何平衡发电机?我如何让N个这些生成器在自己的线程中运行。

一种想法是使用dataset.mapnum_parallel_calls来处理线程;但地图是在张量上运行的......另一个想法是创建多个生成器,每个生成器都有自己的prefetch并以某种方式加入它们,但我无法看到我如何加入N发电机流?

我可以遵循任何规范的例子吗?

3 个答案:

答案 0 :(得分:21)

事实证明,如果我使生成器超级轻量级​​(仅生成元数据),然后将实际重度照明移动到无状态函数,我可以使用Dataset.map。通过这种方式,我可以使用.map py_func将重型提升部件与num_parallel_calls并行化。

作品;但感觉有点笨拙...能够将from_generator添加到def pure_numpy_and_pil_complex_calculation(metadata, label): # some complex pil and numpy work nothing to do with tf ... dataset = tf.data.Dataset.from_generator(lightweight_generator, output_types=(tf.string, # metadata tf.string)) # label def wrapped_complex_calulation(metadata, label): return tf.py_func(func = pure_numpy_and_pil_complex_calculation, inp = (metadata, label), Tout = (tf.uint8, # (H,W,3) img tf.string)) # label dataset = dataset.map(wrapped_complex_calulation, num_parallel_calls=8) dataset = dataset.batch(64) iter = dataset.make_one_shot_iterator() imgs, labels = iter.get_next() 会很棒:)

<script>
        $(document).ready(function () {
            $('path').mouseup(function () {
                document.getElementById('state').innerHTML = $(this).attr('aria-label');
                var state_lbl = document.getElementById('state').innerHTML = $(this).attr('aria-label');
                loadstate(state_lbl);

            })
        });

        function loadstate(state_lal) {
            $.ajax({
                type: "POST",
                url: "mapreq",
                data: {'state': state_lal}
            });
        }
    </script>

答案 1 :(得分:6)

我正在为from_indexable https://github.com/tensorflow/tensorflow/issues/14448

tf.data.Dataset工作

from_indexable的优点是它可以并行化,而python生成器无法并行化。

函数from_indexable生成tf.data.range,将可索引包装在通用tf.py_func中并调用map。

对于那些现在需要from_indexable的人,这里是lib代码

import tensorflow as tf
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args):
            nonlocal output_shapes

            flat_output_types = nest.flatten(output_types)
            flat_values = tf.py_func(
                func, 
                inp=args, 
                Tout=flat_output_types,
                stateful=stateful, name=name
            )
            if output_shapes is not None:
                # I am not sure if this is nessesary
                output_shapes = nest.map_structure_up_to(
                    output_types, tensor_shape.as_shape, output_shapes)
                flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                for ret_t, shape in zip(flat_values, flattened_shapes):
                    ret_t.set_shape(shape)
            return nest.pack_sequence_as(output_types, flat_values)
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)

这里有一个例子(注意:from_indexable有一个num_parallel_calls argument

class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item+1, 10))

ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
    print(sess.run(entry).shape)
    print(sess.run(entry).shape)

更新 2018年6月10日: 自https://github.com/tensorflow/tensorflow/pull/15121合并后,from_indexable的代码简化为:

import tensorflow as tf

def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            return tf.contrib.framework.py_func(
                func=func, 
                args=args, kwargs=kwargs, 
                output_types=output_types, output_shapes=output_shapes, 
                stateful=stateful, name=name
            )
        return call
    return decorator

def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    ds = tf.data.Dataset.range(len(iterator))
    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]    
    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)

答案 2 :(得分:3)

generator中完成的工作限制到最低限度并使用map并行化昂贵的处理工作是明智的。

或者,您可以使用parallel_interleave“加入”多个生成器,如下所示:

def generator(n):
  # returns n-th generator function

def dataset(n):
  return tf.data.Dataset.from_generator(generator(n))

ds = tf.data.Dataset.range(N).apply(tf.contrib.data.parallel_interleave(dataset, cycle_lenght=N))

# where N is the number of generators you use