TensorFlow CIFAR10 distributed training accuracy

Date: 2017-03-10 00:06:28

Tags: tensorflow

I'm seeing very low accuracy when running CIFAR10 distributed training. Even after running 1M steps on two P2.8X machines (8 Tesla K80 GPUs per machine), I still see a loss of 2.39. I have one ps on each machine and one worker per GPU (16 workers in total), with a batch_size of 8.

After training completes successfully, the accuracy on the validation dataset measured with cifar10_eval is 0.010.
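As far as I understand the tutorial code, that 0.010 figure is the precision @ 1 that cifar10_eval reports, computed roughly like this (a sketch of the evaluation loop; the variable names are mine, not the tutorial's):

import math
import numpy as np

def eval_once(sess, top_k_op, num_examples, batch_size):
    # top_k_op is tf.nn.in_top_k(logits, labels, 1) built over the eval inputs.
    num_iter = int(math.ceil(num_examples / float(batch_size)))
    true_count = 0
    for _ in range(num_iter):
        # Boolean vector: True where the top-1 prediction matches the label.
        predictions = sess.run(top_k_op)
        true_count += np.sum(predictions)
    # Fraction of correctly classified examples (the reported precision @ 1).
    return true_count / float(num_iter * batch_size)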

I'm running a distributed version of CIFAR10 training using the model from the TensorFlow tutorial, adapted with the distributed TensorFlow example code to run in distributed mode. The distributed version of the code is below. What is wrong in this code, and how can I fix it to get higher accuracy?

The command used to launch the ps and worker processes looks like this:
CUDA_VISIBLE_DEVICES='0' python cifar10_multi_machine_train.py --batch_size 8
--data_dir=./cifar10_data --train_dir=./train_logs --ps_hosts=host1:2222,host2:2222
--worker_hosts=host1:2230,host1:2231,host1:2232,host1:2233,host1:2234,host1:2235,
host1:2236,host1:2237,host2:2230,host2:2231,host2:2232,host2:2233,
host2:2234,host2:2235,host2:2236,host2:2237
--job_name=worker --task_index=0
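For context, cluster and server in the snippet below are built from those flags in the usual way from the distributed TensorFlow example, roughly like this (a sketch; the flag definitions for batch_size, data_dir, train_dir and num_steps are omitted):

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

ps_hosts = FLAGS.ps_hosts.split(",")
worker_hosts = FLAGS.worker_hosts.split(",")

# One task per host:port entry: 2 ps tasks and 16 worker tasks here.
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)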

Distributed training code:

if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":

    # Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):

        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build inference Graph.
        logits = cifar10.inference(images)

        # Build the portion of the Graph calculating the losses. Note that we will
        # assemble the total_loss using a custom function below.
        loss = cifar10.loss(logits, labels)

        train_op = cifar10.train(loss,global_step)

    # The StopAtStepHook handles stopping after running given steps.
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.num_steps), _LoggerHook()]

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(master=server.target,
                                            is_chief=(FLAGS.task_index == 0),
                                            checkpoint_dir=FLAGS.train_dir,
                                            save_checkpoint_secs=60,
                                            hooks=hooks) as mon_sess:
        while not mon_sess.should_stop():
            # Run a training step asynchronously.
            # See `tf.train.SyncReplicasOptimizer` for additional details on how to
            # perform *synchronous* training.
            # mon_sess.run handles AbortedError in case of preempted PS.
            mon_sess.run(train_op)
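The comment above mentions tf.train.SyncReplicasOptimizer. For reference, the synchronous variant I understand it to mean would look roughly like this (a hypothetical sketch that replaces cifar10.train with an explicit optimizer; 0.1 is just the tutorial's initial learning rate, and this is not what I currently run):

# Wrap an explicit optimizer so gradients from all workers are aggregated
# before each update, instead of being applied asynchronously.
num_workers = len(FLAGS.worker_hosts.split(","))
opt = tf.train.GradientDescentOptimizer(0.1)
opt = tf.train.SyncReplicasOptimizer(opt,
                                     replicas_to_aggregate=num_workers,
                                     total_num_replicas=num_workers)
train_op = opt.minimize(loss, global_step=global_step)

# The sync hook must be passed to MonitoredTrainingSession along with the others.
sync_hook = opt.make_session_run_hook(is_chief=(FLAGS.task_index == 0))
hooks = [tf.train.StopAtStepHook(num_steps=FLAGS.num_steps), sync_hook]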

0 Answers:

There are no answers yet.