I am trying to use MirroredStrategy with tf.Estimator for multi-GPU training. My first attempt was to call tf.train.init_from_checkpoint inside the estimator's model_fn, like this:
def model_fn(features, labels, mode, params):
    ...
    tf.train.init_from_checkpoint(params['resnet_checkpoint'], {'/': 'resnet50/'})
    ...
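To be concrete, here is roughly the full shape of that model_fn; the tiny dense model, the 'image' feature key, and params['num_classes'] below are placeholders for illustration, not my real ResNet-50 code:

import tensorflow as tf

def toy_model_fn(features, labels, mode, params):
    # Build the backbone under the 'resnet50' scope so that the assignment map
    # {'/': 'resnet50/'} lines up with the variable names in the checkpoint.
    with tf.variable_scope('resnet50'):
        net = tf.layers.dense(features['image'], 64, activation=tf.nn.relu)
    logits = tf.layers.dense(net, params['num_classes'])

    # Overwrite the initializers of the matching variables with values from the
    # checkpoint; this is the call that breaks once MirroredStrategy wraps the
    # variables in MirroredVariable objects.
    tf.train.init_from_checkpoint(params['resnet_checkpoint'], {'/': 'resnet50/'})

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)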
This raises the following error:
.../tensorflow/contrib/distribute/python/values.py", line 285, in _get_update_device
"Use DistributionStrategy.update() to modify a MirroredVariable.")
The next attempt was to use tf.estimator.WarmStartSettings, as shown below.
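(var_name_to_prev_var_name here is just a plain dict mapping variable names in my graph to the names stored in the checkpoint; the entries below are made-up placeholders so the snippet is self-contained, not my real mapping:)

# Hypothetical remapping, for illustration only; tf.train.list_variables(ckpt)
# can be used to inspect the names actually stored in the checkpoint.
var_name_to_prev_var_name = {
    'resnet50/conv1/kernel': 'conv1/kernel',
    'resnet50/conv1/bias': 'conv1/bias',
}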
ws = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from=params['resnet_checkpoint'],
    vars_to_warm_start='resnet50.*',
    var_name_to_prev_var_name=var_name_to_prev_var_name
)
session_config = tf.ConfigProto(allow_soft_placement=True)

if FLAGS.num_gpus == 0:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
elif FLAGS.num_gpus == 1:
    distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
else:
    distribution = tf.contrib.distribute.MirroredStrategy(
        num_gpus=FLAGS.num_gpus
    )

run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                    session_config=session_config)

estimator = tf.estimator.Estimator(
    model_fn=model_function,
    params=params,
    config=run_config,
    model_dir=FLAGS.model_dir,
    warm_start_from=ws
)
Again, this raises an error:
TypeError: var MUST be one of the following: a Variable, list of Variable or PartitionedVariable, but is <class 'tensorflow.contrib.distribute.python.values.MirroredVariable'>
Any ideas on how to make either of these approaches work?
Answer 0 (score: 0)
Unfortunately, MirroredStrategy does not yet support restoring from a checkpoint via either of the two mechanisms you tried. I have filed a GitHub issue to track this: https://github.com/tensorflow/tensorflow/issues/19958. Please follow that issue for progress.