Tensorflow - 如何使用Estimator将input_fn用于GPU

时间:2018-05-04 07:34:39

标签: python tensorflow deep-learning tensorflow-estimator

我目前正在使用tensorflow的官方模型分支中的resnet训练代码,并且我在从Estimator调用时遇到了与input_fn的设备放置有关的问题。

当从Estimator类中的_call_input_fn调用input_fn时,默认设备位置始终为CPU - [r1.8 source]

with ops.device('/cpu:0'):
  return input_fn(**kwargs)

虽然官方文档建议将输入数据管道放置在CPU上,以便GPU仅用于训练,但可能存在需要在数据加载阶段执行前向推断的情况。例如,在对抗性攻击期间,通过计算相对于输入的渐变来扰乱输入图像。

使用Estimators时,有没有办法覆盖输入数据加载的显式CPU设备位置?

0 个答案:

没有答案