热烈启动distribute.MirroredStrategy和tf.Estimator

时间:2018-06-08 10:00:38

标签: tensorflow

我尝试使用MirroredStartegy和tf.Estimator进行多gpus培训。第一次尝试是在估算工具tf.train.init_from_chekpoint中使用model_fn,如下所示

def model_fn(features, labels, mode, params):

    .....

   tf.train.init_from_checkpoint(params['resnet_checkpoint'], {'/': 'resnet50/'})

   ....

这会引发以下错误

.../tensorflow/contrib/distribute/python/values.py", line 285, in _get_update_device
    "Use DistributionStrategy.update() to modify a MirroredVariable.")

下一次尝试是使用tf.estimator.WarmStartSetting

ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=params['resnet_checkpoint'],
        vars_to_warm_start='resnet50.*',
        var_name_to_prev_var_name=var_name_to_prev_var_name
    )

session_config = tf.ConfigProto(allow_soft_placement=True)

if FLAGS.num_gpus == 0:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
elif FLAGS.num_gpus == 1:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
else:
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus
        )
run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                        session_config=session_config)

estimator = tf.estimator.Estimator(
        model_fn=model_function,
        params=params,
        config=run_config,
        model_dir=FLAGS.model_dir,
        warm_start_from=ws
    )

同样,这会引发错误

TypeError: var MUST be one of the following: a Variable, list of Variable or PartitionedVariable, but is <class 'tensorflow.contrib.distribute.python.values.MirroredVariable'>

要解决这两种方法之一的想法吗?

1 个答案:

答案 0 :(得分:0)

遗憾的是,MirroredStrategy尚未支持使用您尝试过的2种机制从检查点恢复。我已经提交了一个github问题来跟踪这个https://github.com/tensorflow/tensorflow/issues/19958。请按照此问题获取进展。