Tensorflow数据集.map()API

时间:2018-03-14 05:45:05

标签: python tensorflow tensorflow-datasets

关于此问题的几个问题

对于我想在Tensorflow中执行以下操作的情况(假设我通过加载WAV文件创建训练示例):

import tensorflow as tf 

def _some_audio_preprocessing_func(filename):
   # ... some logic here which mostly uses Tensorflow ops ...
   with tf.Session(graph=tf.Graph()) as sess:
        wav_filename_placeholder = tf.placeholder(tf.string, [])
        wav_loader = io_ops.read_file(wav_filename_placeholder)
        wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
        data = sess.run(
                [wav_decoder],
                feed_dict={wav_filename_placeholder: filename})
        return data

dataset = tf.data.Dataset.list_files('*.wav')
dataset = dataset.map(_some_preprocessing_func)
  1. 如果我有一个使用张量操作的parse_image()函数 - 应该 这是主图的一部分?按照示例集in Google's own audio TF tutorial,看起来他们创建了一个单独的图表!这不会破坏使用Tensorflow加快速度的重要性吗?
  2. 我是否在张量流库中的任何一行都没有使用tf.py_func()?同样,我想知道性能影响是什么以及何时应该使用它......
  3. 谢谢!

1 个答案:

答案 0 :(得分:7)

当您使用Dataset.map(map_func)时,TensorFlow为函数map_func中创建的所有操作定义子图,并安排在与图的其余部分相同的会话中有效地执行它。几乎不需要在tf.Graph内创建tf.Sessionmap_func:如果您的解析函数由TensorFlow操作组成,则这些操作可以直接嵌入到定义输入管道。

使用tf.data的代码的修改版本如下所示:

import tensorflow as tf 
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

def _some_audio_preprocessing_func(filename):
    wav_loader = tf.read_file(filename)
    return contrib_audio.decode_wav(wav_loader, desired_channels=1)

dataset = tf.data.Dataset.list_files('*.wav')
dataset = dataset.map(_some_preprocessing_func)

如果您的map_func包含要应用于每个元素的非TensorFlow操作,则应将它们包装在tf.py_func()(或Dataset.from_generator()中,如果定义了数据生成过程在Python逻辑中)。主要的性能含义是在tf.py_func()中运行的任何代码都受Global Interpreter Lock的约束,因此我通常建议尝试为性能至关重要的任何内容找到本机TensorFlow实现。