大多数教程都关注整个训练数据集适合内存的情况。但是,我有一个迭代器,它充当(特征,标签)-tuples的无限流(在运行中便宜地创建它们)。
为tensorflow estimator实现input_fn
时,我可以从迭代器返回一个实例
def input_fn():
(feature_batch, label_batch) = next(it)
return tf.constant(feature_batch), tf.constant(label_batch)
或input_fn
必须在每次通话时返回相同的(功能,标签) - 元组吗?
此外,在训练期间多次调用此函数,因为我希望它类似于以下伪代码:
for i in range(max_iter):
learn_op(input_fn())
答案 0 :(得分:2)
input_fn
的参数在整个训练中使用,但函数本身被调用一次。因此,如tutorial中所述,创建一个复杂的input_fn
不仅仅是返回一个常量数组也不是那么简单。
Tensorflow为numpy和panda数组提供了两个非平凡的input_fn
示例,但它们从内存中的数组开始,因此这对您的问题没有帮助。
您还可以按照上面的链接查看他们的代码,看看他们如何实现有效的非平凡input_fn
,但您可能会发现它需要更多您想要的代码。
如果您愿意使用Tensorflow的低级别界面,那么IMHO更简单,更灵活。有一个tutorial可以满足大多数需求,建议的解决方案很容易实现。
特别是,如果您已经有一个按照问题中描述的方式返回数据的迭代器,那么使用占位符(上一个链接中的“Feeding”部分)应该很简单。
答案 1 :(得分:2)
我找到了一个将generator
转换为input_fn
的拉取请求:
https://github.com/tensorflow/tensorflow/pull/7045/files
相关部分是
def _generator_input_fn():
"""generator input function."""
queue = feeding_functions.enqueue_data(
x,
queue_capacity,
shuffle=shuffle,
num_threads=num_threads,
enqueue_size=batch_size,
num_epochs=num_epochs)
features = (queue.dequeue_many(batch_size) if num_epochs is None
else queue.dequeue_up_to(batch_size))
if not isinstance(features, list):
features = [features]
features = dict(zip(input_keys, features))
if target_key is not None:
if len(target_key) > 1:
target = {key: features.pop(key) for key in target_key}
else:
target = features.pop(target_key[0])
return features, target
return features
return _generator_input_fn
答案 2 :(得分:0)
from tensorflow.contrib.learn.python.learn.learn_io import generator_io
import numpy as np
# define generator
def generator():
for index in range(2):
yield {'a': np.ones(1) * index,'b': np.ones(1) * index + 32,'label': np.ones(1) * index - 32}
input_fn = generator_io.generator_input_fn(generator, target_key='label', batch_size=2, shuffle=False, num_epochs=1)
features, target = input_fn()
请参阅测试用例https://github.com/tensorflow/tensorflow/pull/7045/files