我目前正在使用tensorflow的官方模型分支中的resnet训练代码,并且我在从Estimator调用时遇到了与input_fn的设备放置有关的问题。
当从Estimator类中的_call_input_fn调用input_fn时,默认设备位置始终为CPU - [r1.8 source]。
with ops.device('/cpu:0'):
return input_fn(**kwargs)
虽然官方文档建议将输入数据管道放置在CPU上,以便GPU仅用于训练,但可能存在需要在数据加载阶段执行前向推断的情况。例如,在对抗性攻击期间,通过计算相对于输入的渐变来扰乱输入图像。
使用Estimators时,有没有办法覆盖输入数据加载的显式CPU设备位置?