我在TF上使用高级估算器:
estim = tf.contrib.learn.Estimator(...)
estim.fit ( some_input )
如果 some_input 有x
,y
和batch_size
,则代码会运行,但会显示警告;所以我尝试使用input_fn
,并设法通过此x
发送y
,input_fn
,但不发送batch_size
。没有找到任何例子。
是否有人可以分享使用input_fn
作为estim.fit
/ estim.evaluate
输入的简单示例,并使用batch_size
?
我必须使用tf.train.batch
吗?如果是这样,它如何合并到更高级别的实现(tf.layers
) - 我不知道图形的tf.Graph()或会话?
以下是我收到的警告:
警告:tensorflow:来自/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/learn/python/learn/monitors.py:657:调用evaluate
不推荐使用带有y的(来自tensorflow.contrib.learn.python.learn.estimators.estimator),并将在2016-12-01之后删除。
更新说明: 通过迁移,Estimator与Scikit Learn界面分离 单独的类SKCompat。参数x,y和batch_size仅为 在SKCompat类中,Estimator只接受input_fn。
转换示例:
est = Estimator(...) - > est = SKCompat(Estimator(...))
答案 0 :(得分:4)
link provided in Roi's own comment确实非常有帮助。由于我一直在努力解决同样的问题,我想总结上面链接提供的答案作为参考:
def batched_input_fn(dataset_x, dataset_y, batch_size):
def _input_fn():
all_x = tf.constant(dataset_x, shape=dataset_x.shape, dtype=tf.float32)
all_y = tf.constant(dataset_y, shape=dataset_y.shape, dtype=tf.float32)
sliced_input = tf.train.slice_input_producer([all_x, all_y])
return tf.train.batch(sliced_input, batch_size=batch_size)
return _input_fn
然后可以像这个例子一样使用它(使用TensorFlow v1.1):
model = CustomModel(FLAGS.learning_rate)
estimator= tf.estimator.Estimator(model_fn=model.build(), params=model.params())
estimator.train(input_fn=batched_input_fn(
train.features,
train.labels,
FLAGS.batch_size),
steps=FLAGS.train_steps)
不幸的是,与手动进纸(使用TensorFlows低级API)或使用整个数据集train.shape[0] == batch_size
并且不使用{{1}相比,这种方法慢10倍 }和train.sliced_input_producer()
。至少在我的机器上(仅限CPU)。我真的很想知道为什么这种方法如此缓慢。有什么想法吗?
<强>编辑:强>
我可以使用train.batch()
&gt;加快速度1作为num_threads
的参数。在具有2个CPU的VM上,与默认的train.batch()
相比,我可以使用此批处理机制将性能提高一倍。但仍然比手动喂食<强>慢5倍。
但是在本机系统或使用输入管道的所有CPU核心和用于模型计算的GPU的系统上,结果可能会有所不同。如果有人可以在评论中发表他的结果,那将会很棒。