如何使用StreamingDataFeeder作为contrib.learn.Estimator.fit()' s input_fn?

时间:2016-10-04 14:48:47

标签: tensorflow skflow

我最近开始使用tensorflow.contrib.learn(skflow)库并且非常喜欢它。但是,我在使用Estimator时遇到问题,fit函数使用

  1. XYbatch_size) - 此方法的问题在于它不支持指定时期数和允许任意数据源的规定。
  2. input_fn - 此外,设置时代,它为我提供了更多的培训来源灵活性(在我的情况下直接来自数据库)。
  3. 现在我知道我可以创建读取文件的input_fn,但是,由于我对处理文件不感兴趣,以下函数对我没用 -

    • tf.contrib.learn.read_batch_examples
    • tf.contrib.learn.read_batch_features
    • tf.contrib.learn.read_batch_record_features

    理想情况下,我想使用StreamingDataFeeder作为input_fn。我有什么想法可以实现这个目标吗?

1 个答案:

答案 0 :(得分:0)

当您向StreamingDataFeeder / x / y / fit predict提供evaluate / Estimator的迭代器时,会使用

x = (np.array([i]) for i in xrange(10**10)) # use range for python >=3.0 y = (np.array([i + 1]) for i in xrange(10**10)) lr = tf.contrib.learn.LinearRegressor( feature_columns=[tf.contrib.layers.real_valued_column('')]) # only consumes 1000*10 values from iterators. lr.fit(x, y, steps=1000, batch_size=10)

示例:

input_fn

如果要使用Tensor来提供数据,则需要使用图形操作来读取/处理数据。例如,您可以创建一个C ++操作来生成您的数据(它可以是侦听端口或从数据库Op读取)并转换为#include <ncurses.h> struct TEST_STRUCT { char nCharacter; // Where I want to store variable for printed character short nTestNumber; // Other stuff in struct }; TEST_STRUCT sTestData[] = { { '.', 1 }, // Period { ',', 2 }, // Comma { ACS_VLINE, 1 } // Vertical Line }; int main(void) { initscr(); clear(); for( int n = 0; n < 3; n++) { addch(sTestData[n].nCharacter); // print the characters in the struct } refresh(); endwin(); return 0; } 。主要是这对于从文件中读取数据很有用,但也可以实现其他读者。