I'm seeing very low accuracy when running distributed training of CIFAR-10. Even after running 1M steps on two P2.8X instances (8 Tesla K80 GPUs per machine), the loss is still 2.39. I run one ps on each machine and one worker per GPU (16 workers in total), with a batch_size of 8.
After training finishes, the accuracy on the validation dataset reported by cifar10_eval is 0.010.
I am running a distributed version of the CIFAR-10 training code from the TensorFlow tutorial, adapted with the distributed TensorFlow code example so it runs in distributed mode. The distributed version of the code is shown below. What is wrong with this code, and how can I fix it to get higher accuracy?
The command used to launch the ps and worker processes is similar to:
CUDA_VISIBLE_DEVICES='0' python cifar10_multi_machine_train.py --batch_size 8 \
  --data_dir=./cifar10_data --train_dir=./train_logs \
  --ps_hosts=host1:2222,host2:2222 \
  --worker_hosts=host1:2230,host1:2231,host1:2232,host1:2233,host1:2234,host1:2235,host1:2236,host1:2237,host2:2230,host2:2231,host2:2232,host2:2233,host2:2234,host2:2235,host2:2236,host2:2237 \
  --job_name=worker --task_index=0
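The ps processes are started from the same script with the same host lists; the only difference (my assumption, the exact ps invocation was not shown above) is --job_name=ps with the matching task index, roughly:
python cifar10_multi_machine_train.py --batch_size 8 \
  --data_dir=./cifar10_data --train_dir=./train_logs \
  --ps_hosts=host1:2222,host2:2222 \
  --worker_hosts=<same list as above> \
  --job_name=ps --task_index=0   # task_index=1 for the ps on host2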
The distributed training code:
if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":
    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build inference Graph.
        logits = cifar10.inference(images)

        # Build the portion of the Graph calculating the losses. Note that we will
        # assemble the total_loss using a custom function below.
        loss = cifar10.loss(logits, labels)

        train_op = cifar10.train(loss, global_step)

    # The StopAtStepHook handles stopping after running given steps.
    hooks = [tf.train.StopAtStepHook(num_steps=FLAGS.num_steps), _LoggerHook()]

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0),
                                           checkpoint_dir=FLAGS.train_dir,
                                           save_checkpoint_secs=60,
                                           hooks=hooks) as mon_sess:
        while not mon_sess.should_stop():
            # Run a training step asynchronously.
            # See `tf.train.SyncReplicasOptimizer` for additional details on how to
            # perform *synchronous* training.
            # mon_sess.run handles AbortedError in case of preempted PS.
            mon_sess.run(train_op)
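For context, the snippet above references cluster and server, which are created earlier from the --ps_hosts/--worker_hosts flags. Below is a minimal sketch of that setup, assumed to follow the standard distributed TensorFlow example; the flag-parsing details are my reconstruction, not the exact code from my script:

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

# Split the comma-separated host:port lists from the command line into
# the two jobs that make up the cluster.
ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")

# Two ps tasks (one per machine) and sixteen worker tasks (one per GPU).
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})

# Start the in-process gRPC server for this task; the training code above
# then branches on FLAGS.job_name.
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)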