在Tensorflow中的input_fn中生成无限随机训练数据

时间:2017-07-31 10:37:34

标签: python tensorflow

是否可以创建一个input_fn无限生成随机数据以与Tensorflow中的Estimator API一起使用?

这基本上就是我想要的:

def create_input_fn(function_to_generate_one_sample_with_label):
    def _input_fn():
        ### some code ###
        return feature_cols, labels

然后我想将这个函数与Estimator实例一起使用,如下所示:

def data_generator():
    features = ... generate a (random) feature vector ...
    lablel = ... create suitable label ...
    return features, labels

input_fn = create_input_fn(data_generator)
estimator.train(input_fn=input_fn, steps=ANY_NUMBER_OF_STEPS)

重点是能够根据需要训练多个步骤,即时生成所需的训练数据。这是为了模型调整的目的,能够尝试不同复杂程度的不同训练数据,以便我可以了解模型适合训练数据的能力。

修改 正如jkm建议的那样,我尝试使用一个实际的生成器,如下所示:

def create_input_fn(function, batch_size=100):  
    def create_generator():
        while True:
            features = ... generate <batch_size> feature vectors ...
            lablel = ... create <batch_size> labels ...
            yield features, label
    g = create_generator()
    def _input_fn():
        return next(g)
    return _input_fn

我必须添加批量大小才能让它运行。它现在运行,但input_fn仅被调用一次,因此它不会生成任何新数据。它只训练生成的第一个<batch_size>样本。有没有办法告诉估算工具使用提供的input_fn

刷新数据

3 个答案:

答案 0 :(得分:1)

我认为您可以使用最新的Tf数据集API获得所需的行为,您需要tensorflow&gt; = 1.2.0

# Define number of samples and input shape for each iteration
# you can set minval or maxval as per you data distribution and label distributon requirements
 num_samples = [20000,]
 input_shape = [32, 32, 3]
dataset = tf.contrib.data.Dataset.from_tensor_slices((tf.random_normal([num_examples+input_shape]),  tf.random_uniform([num_samples], minval=0, maxval=5)))
# Define batch_size
batch_size = 128
dataset = dataset.batch(batch_size)
# Define iterator
iterator = dataset.make_initializable_iterator()
# Get one batch
next_example, next_label = iterator.get_next()
# calculate loss from the estimator fucntion you are using
estimator_loss = some_estimator(next_example, next_label)
# Set number of Epochs here
num_epochs = 100
for _ in range(num_epochs):
    sess.run(iterator.initializer)
    while True:
        try:
            _loss = sess.run(estimator_loss)
        except tf.errors.OutOfRangeError:
            break

答案 1 :(得分:0)

提示警告 - 我自己与Tensorflow合作,我只是关闭了API的文档。

那就是说 - 如果那里没有问题,你应该能够做你需要的事情。只需将生成器设为a,生成器(生成功能和标签而不是返回它们)并将整个生成器置于无限循环中。例如:

def data_generator():
    while True:
        #do generatey things here
        yield feature, labels

每次调用时,每次生成一次新值时,都可以重复调用此函数。

答案 2 :(得分:0)

请问:您是否执行数据扩充以生成数据?如果是这样,只要您使用tensorflow框架中的随机函数,被调用的input_fn函数就会生成无限数量的随机样本。 (像tf.random_uniform之类的东西,而不是来自numpy的相应方法等。)这对我有用。