在tensorflow 1.8中,tf.Estimators
在其MirroredStrategy
中支持RunConfig
进行多GPU训练(如果您使用Dataset API)。但是,这意味着tf.estimator.inputs.numpy_input_fn
不起作用,因为它不使用新的Dataset API。
支持Dataset API的input_fn
必须返回一个tf.Dataset.Dataset
实例。
在tensorflow文档中,存在两种填充tf.Dataset.Dataset
的方法:
tf.Constant
(使用from_tensor_slices
)。这不是我的选择,因为我的numpy数组超出了图形protobuf施加的2GB限制。tf.Placeholder
张量创建数据集,并在运行时提供它们。但这需要将numpy数组作为feed_dict
传递给session.run。我想使用选项2,但是在训练/推断期间似乎没有办法将feed_dict
传递给我的估计器实例。有关如何执行此操作的任何想法?