将input_fn用于tf.contrib.learn.Estimator

时间:2017-03-14 17:19:07

标签: python tensorflow

我在TF上使用高级估算器:

estim = tf.contrib.learn.Estimator(...)
estim.fit ( some_input )

如果 some_input xybatch_size,则代码会运行,但会显示警告;所以我尝试使用input_fn,并设法通过此x发送yinput_fn,但不发送batch_size。没有找到任何例子。

是否有人可以分享使用input_fn作为estim.fit / estim.evaluate输入的简单示例,并使用batch_size

我必须使用tf.train.batch吗?如果是这样,它如何合并到更高级别的实现(tf.layers) - 我不知道图形的tf.Graph()或会话?

以下是我收到的警告:

  

警告:tensorflow:来自/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py:657:调用evaluate

     不推荐使用带有y的

(来自tensorflow.contrib.learn.python.learn.estimators.estimator),并将在2016-12-01之后删除。

     

更新说明:   通过迁移,Estimator与Scikit Learn界面分离   单独的类SKCompat。参数x,y和batch_size仅为   在SKCompat类中,Estimator只接受input_fn。

     

转换示例:

     

est = Estimator(...) - > est = SKCompat(Estimator(...))

1 个答案:

答案 0 :(得分:4)

link provided in Roi's own comment确实非常有帮助。由于我一直在努力解决同样的问题,我想总结上面链接提供的答案作为参考:

def batched_input_fn(dataset_x, dataset_y, batch_size):
    def _input_fn():
        all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32)
        all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32)
        sliced_input = tf.train.slice_input_producer([all_x, all_y])
        return tf.train.batch(sliced_input, batch_size=batch_size)
    return _input_fn

然后可以像这个例子一样使用它(使用TensorFlow v1.1):

model = CustomModel(FLAGS.learning_rate)
estimator= tf.estimator.Estimator(model_fn=model.build(), params=model.params())

estimator.train(input_fn=batched_input_fn(
       train.features, 
       train.labels,
       FLAGS.batch_size),
    steps=FLAGS.train_steps)

不幸的是,与手动进纸(使用TensorFlows低级API)或使用整个数据集train.shape[0] == batch_size并且不使用{{1}相比,这种方法慢10倍 }和train.sliced_input_producer()。至少在我的机器上(仅限CPU)。我真的很想知道为什么这种方法如此缓慢。有什么想法吗?

<强>编辑:

我可以使用train.batch()&gt;加快速度1作为num_threads的参数。在具有2个CPU的VM上,与默认的train.batch()相比,我可以使用此批处理机制将性能提高一倍。但仍然比手动喂食<强>慢5倍。 但是在本机系统或使用输入管道的所有CPU核心和用于模型计算的GPU的系统上,结果可能会有所不同。如果有人可以在评论中发表他的结果,那将会很棒。