在由numpy数组提供的Estimator中将input_fn与tf.Dataset结合使用

时间:2018-06-22 00:37:04

标签: tensorflow tensorflow-datasets tensorflow-estimator

在tensorflow 1.8中,tf.Estimators在其MirroredStrategy中支持RunConfig进行多GPU训练(如果您使用Dataset API)。但是,这意味着tf.estimator.inputs.numpy_input_fn不起作用,因为它不使用新的Dataset API。

支持Dataset API的input_fn必须返回一个tf.Dataset.Dataset实例。

在tensorflow文档中,存在两种填充tf.Dataset.Dataset的方法:

  1. 将numpy数组嵌入为tf.Constant(使用from_tensor_slices)。这不是我的选择,因为我的numpy数组超出了图形protobuf施加的2GB限制。
  2. tf.Placeholder张量创建数据集,并在运行时提供它们。但这需要将numpy数组作为feed_dict传递给session.run。

我想使用选项2,但是在训练/推断期间似乎没有办法将feed_dict传递给我的估计器实例。有关如何执行此操作的任何想法?

0 个答案:

没有答案