从迭代器

时间:2017-05-03 13:11:51

标签: python tensorflow

大多数教程都关注整个训练数据集适合内存的情况。但是,我有一个迭代器,它充当(特征,标签)-tuples的无限流(在运行中便宜地创建它们)。

为tensorflow estimator实现input_fn时,我可以从迭代器返回一个实例

def input_fn():
   (feature_batch, label_batch) = next(it)
   return tf.constant(feature_batch), tf.constant(label_batch)

input_fn必须在每次通话时返回相同的(功能,标签) - 元组吗?

此外,在训练期间多次调用此函数,因为我希望它类似于以下伪代码:

for i in range(max_iter):
   learn_op(input_fn())

3 个答案:

答案 0 :(得分:2)

input_fn的参数在整个训练中使用,但函数本身被调用一次。因此,如tutorial中所述,创建一个复杂的input_fn不仅仅是返回一个常量数组也不是那么简单。

Tensorflow为numpypanda数组提供了两个非平凡的input_fn示例,但它们从内存中的数组开始,因此这对您的问题没有帮助。

您还可以按照上面的链接查看他们的代码,看看他们如何实现有效的非平凡input_fn,但您可能会发现它需要更多您想要的代码。

如果您愿意使用Tensorflow的低级别界面,那么IMHO更简单,更灵活。有一个tutorial可以满足大多数需求,建议的解决方案很容易实现。

特别是,如果您已经有一个按照问题中描述的方式返回数据的迭代器,那么使用占位符(上一个链接中的“Feeding”部分)应该很简单。

答案 1 :(得分:2)

我找到了一个将generator转换为input_fn的拉取请求: https://github.com/tensorflow/tensorflow/pull/7045/files

相关部分是

  def _generator_input_fn():
    """generator input function."""
    queue = feeding_functions.enqueue_data(
      x,
      queue_capacity,
      shuffle=shuffle,
      num_threads=num_threads,
      enqueue_size=batch_size,
      num_epochs=num_epochs)

    features = (queue.dequeue_many(batch_size) if num_epochs is None
                else queue.dequeue_up_to(batch_size))
    if not isinstance(features, list):
      features = [features]
    features = dict(zip(input_keys, features))
    if target_key is not None:
      if len(target_key) > 1:
        target = {key: features.pop(key) for key in target_key}
      else:
        target = features.pop(target_key[0])
      return features, target
    return features
  return _generator_input_fn

答案 2 :(得分:0)

from tensorflow.contrib.learn.python.learn.learn_io import generator_io
import numpy as np

# define generator
def generator():
    for index in range(2):
        yield {'a': np.ones(1) * index,'b': np.ones(1) * index + 32,'label': np.ones(1) * index - 32}

input_fn = generator_io.generator_input_fn(generator, target_key='label', batch_size=2, shuffle=False, num_epochs=1)
features, target = input_fn()

请参阅测试用例https://github.com/tensorflow/tensorflow/pull/7045/files