张量流,高gpu利用率但训练速度低

时间:2017-07-05 08:45:07

标签: tensorflow tensorflow-gpu

我在文本示例上训练cnn模型时遇到了问题。

  1. 在单GPU上进行训练时,我的GPU利用率非常高,约为97%,但训练速度非常慢。 1000批次需要450s(每批64个例子),因此每个例子为7ms。相比之下,每个示例的分层lstm只需要2~3ms。
  2. 我尝试在GPU集群上部署我的训练过程,但得到了一个奇怪的GPU利用率。我使用了4个GPU,在大多数时间内利用率为0%。我曾尝试将批量大小从64修改为2,然后GPU利用率变为正常,但小批量大小将导致性能低下。所以我想问一下有没有一种有效的方法可以通过使用GPU集群加速训练过程。
  3. (顺便说一句,这些问题只发生在单个输入示例非常大时,例如包含数千个单词的新闻的正文内容。当输入是新闻标题时,GPU集群工作正常)

    输入格式 [64(示例/批)× 2500(词/例)× 200(嵌入维度)] 对于我的5层cnn模型而言是否过大而无法正确训练?

    1.模型定义(改编自 https://github.com/dennybritz/cnn-text-classification-tf )

    import tensorflow as tf
    import numpy as np
    
    class TextCNN(object):
        """
        A CNN for text classification of long documents.

        Pipeline: word-id input -> embedding lookup -> one conv/pool tower
        per entry in `filter_sizes` (three conv layers each, ending in a
        small fully-connected head) -> concat of tower outputs -> shared FC
        layer -> dropout -> linear softmax scores.

        Constructor args (example values from the caller in comments below):
            sequence_length: words per document.
            num_classes:     number of output classes.
            vocab_size:      rows in the embedding table.
            embedding_size:  embedding dimensionality.
            filter_sizes:    first-layer filter heights, one tower each.
            num_filters:     channel counts for conv layers 1-3.
            l2_reg_lambda:   weight of the L2 penalty (output layer only).
        """
    
        # sequence_length : 2500 (words per doc)
        # num_classes : 36
        # vocab_size : 500,000
        # embedding size : 200
        # filter_sizes : [25, 50, 100]
        # num_filters : [32, 64, 128]
        def __init__(
          self, sequence_length, num_classes, vocab_size,
          embedding_size, filter_sizes, num_filters, l2_reg_lambda=0.0):
    
            # Placeholders for input, output and dropout
            self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
            self.input_y = tf.placeholder(tf.float32, [None, num_classes], name="input_y")
            self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
    
            # Keeping track of l2 regularization loss (optional).
            # Only the final output layer's W/b are accumulated into it below;
            # conv/FC weights are NOT regularized.
            l2_loss = tf.constant(0.0)
    
            # Embedding layer: one trainable [vocab_size, embedding_size]
            # table, uniform-initialized in [-1, 1).
            self.W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name="W")
            self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
            # Add a trailing channel axis ->
            # [batch, sequence_length, embedding_size, 1] for conv2d.
            self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)
    
            # Despite the name, this collects one FC head output per tower
            # (not raw pooled feature maps); concatenated after the loop.
            pooled_outputs = []
    
    
            # VALID convolution with per-axis strides: stride_h over the word
            # axis, stride_w over the embedding axis.
            def conv2d(x, W, stride_h, stride_w):
                return tf.nn.conv2d(x, W, strides=[1,stride_h,stride_w,1], padding="VALID")
    
            # Non-overlapping max pooling (ksize == strides).
            def max_pool(x, h, w):
                return tf.nn.max_pool(x, ksize=[1,h,w,1], strides=[1,h,w,1], padding="VALID")
    
            # Truncated-normal weight init (stddev 0.1).
            def weight_variable(shape):
                initial = tf.truncated_normal(shape, stddev=0.1)
                return tf.Variable(initial)
    
            # Constant 0.1 bias init (keeps ReLUs initially active).
            def bias_variable(shape):
                initial = tf.constant(0.1, shape=shape)
                return tf.Variable(initial)
    
            n_conv1 = num_filters[0]    # 32
            n_conv2 = num_filters[1]    # 64
            n_conv3 = num_filters[2]    # 128
            n_fc1 = 200
            n_fc2 = 200
    
            # filter_sizes : [25, 50, 100]
            for i, filter_size in enumerate(filter_sizes):
                with tf.name_scope("conv-maxpool-%s" % filter_size):
                    # Convolution Layer
                    # NOTE(review): Python-2 print statements below run at
                    # graph-construction time only (shape debugging).
                    print '######## conv-maxpool-%s ########', filter_size
                    # Conv1: [filter_size x 40] kernel with strides (25, 5),
                    # so the 200-wide embedding axis is scanned in chunks
                    # rather than consumed whole as in the Kim-CNN original.
                    w_conv1 = weight_variable([filter_size, 40, 1, n_conv1])
                    b_conv1 = bias_variable([n_conv1])
                    f_conv1 = tf.nn.relu(conv2d(self.embedded_chars_expanded, w_conv1, 25, 5) + b_conv1)
                    print 'conv1: ', f_conv1
                    f_pool1 = max_pool(f_conv1, 2, 2)
                    print 'pool1: ', f_pool1
    
                    w_conv2 = weight_variable([3, 3, n_conv1, n_conv2])
                    b_conv2 = bias_variable([n_conv2])
                    f_conv2 = tf.nn.relu(conv2d(f_pool1, w_conv2, 2, 1) + b_conv2)
                    print 'conv2: ', f_conv2
                    f_pool2 = max_pool(f_conv2, 2, 2)
                    print 'pool2: ', f_pool2
    
                    w_conv3 = weight_variable([2, 2, n_conv2, n_conv3])
                    b_conv3 = bias_variable([n_conv3])
                    f_conv3 = tf.nn.relu(conv2d(f_pool2, w_conv3, 1, 1) + b_conv3)
                    print 'conv3: ', f_conv3
                    f_pool3 = max_pool(f_conv3, 2, 2)
                    print 'pool3: ', f_pool3
    
                    # Hard-coded flattened spatial size (5 x 3) of pool3.
                    # NOTE(review): appears tuned for sequence_length=2500 /
                    # embedding_size=200 with the strides above — the reshape
                    # will fail for other input shapes; confirm before
                    # changing any of those constants.
                    f_size_conv3 = 5 * 3
                    f_pool3_flat = tf.reshape(f_pool3, [-1, f_size_conv3 * n_conv3])
    
                    # Per-tower fully-connected head (sigmoid, n_fc1 units).
                    w_fc1 = weight_variable([f_size_conv3 * n_conv3, n_fc1])
                    b_fc1 = bias_variable([n_fc1])
                    f_fc1 = tf.nn.sigmoid(tf.matmul(f_pool3_flat, w_fc1) + b_fc1)
                    print 'f_fc1: ', f_fc1
                    pooled_outputs.append(f_fc1)
    
            # Concatenate tower outputs -> [batch, n_fc1 * len(filter_sizes)].
            i_fc2 = tf.concat(pooled_outputs, 1)
            print i_fc2
            # Shared FC layer on top of the concatenated tower features.
            w_fc2 = weight_variable([n_fc1*len(filter_sizes), n_fc2])
            b_fc2 = bias_variable([n_fc2])
            f_fc2 = tf.nn.sigmoid(tf.matmul(i_fc2, w_fc2) + b_fc2)
            print 'f_fc2: ', f_fc2
            # Combine all the pooled features
            num_filters_total = n_fc2
            # This reshape is a no-op: f_fc2 is already [batch, n_fc2].
            self.h_pool_flat = tf.reshape(f_fc2, [-1, num_filters_total])
    
            # Add dropout
            with tf.name_scope("dropout"):
                self.h_drop = tf.nn.dropout(self.h_pool_flat, self.dropout_keep_prob)
    
            # Final (unnormalized) scores and predictions
            with tf.name_scope("output"):
                W = tf.get_variable(
                    "W",
                    shape=[num_filters_total, num_classes],
                    initializer=tf.contrib.layers.xavier_initializer())
                b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name="b")
                # Only these two tensors feed the L2 penalty.
                l2_loss += tf.nn.l2_loss(W)
                l2_loss += tf.nn.l2_loss(b)
                self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name="scores")
                self.predictions = tf.argmax(self.scores, 1, name="predictions")
    
            # CalculateMean cross-entropy loss
            with tf.name_scope("loss"):
                # NOTE(review): adding 1e-10 to the *logits* is a numerical
                # no-op — softmax_cross_entropy_with_logits is already
                # numerically stable; the epsilon can be dropped.
                losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.scores+1e-10, labels=self.input_y)
                self.loss = tf.reduce_mean(losses) + l2_reg_lambda * l2_loss
    
            # Accuracy
            with tf.name_scope("accuracy"):
                # "predition" typo kept: it is part of the public attribute name.
                self.correct_predition = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
                # correct_num: float count of correct examples in the batch
                # (summed by the training scripts' evaluate()).
                self.correct_num = tf.reduce_sum(tf.cast(self.correct_predition, tf.float32))
                self.accuracy = tf.reduce_mean(tf.cast(self.correct_predition, "float"), name="accuracy")
    

    2.训练过程(单GPU)

    import tensorflow as tf
    import numpy as np
    import os
    import time
    import datetime
    import data_loader_cnn as data_loader
    from tensorflow.contrib import learn
    
    import sys
    sys.path.append('./model_def')
    from cnn_model import TextCNN
    
    # Data loading params
    tf.flags.DEFINE_string("train_path", "/data/train_data.idx", "Data source for the positive data.")
    tf.flags.DEFINE_string("valid_path", "/data/valid_data.idx", "Data source for the validation data.")
    tf.flags.DEFINE_string("ckpt_dir", "runs-cnn", "Directory for checkpoints.")
    tf.flags.DEFINE_integer("class_num", 36, "Number of total classes")
    tf.flags.DEFINE_integer("vocab_size", 500000, "Number of total distinct words")
    # document_length * sentence_length = 2500 words is the model's
    # sequence_length (see the TextCNN construction in main()).
    tf.flags.DEFINE_integer("document_length", 50, "Max number of sentences in single text")
    tf.flags.DEFINE_integer("sentence_length", 50, "Max number of words in single sentence")
    
    # Model Hyperparameters
    # NOTE(review): several "(default: ...)" help strings below are stale
    # copies from the upstream example and do not match the real defaults.
    tf.flags.DEFINE_integer("embedding_dim", 200, "Dimensionality of character embedding (default: 128)")
    tf.flags.DEFINE_string("filter_sizes", "25,50,100", "Comma-separated filter sizes (default: '3,4,5')")
    tf.flags.DEFINE_string("num_filters", "32,64,128", "Number of filters per filter size (default: 128)")
    tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
    tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")
    tf.flags.DEFINE_float("lr", 0.1, "Learning rate (default: 0.1)")
    tf.flags.DEFINE_float("lr_decay", 0.5, "Learning rate decay per epoch (default: 0.6)")
    tf.flags.DEFINE_integer("max_decay_epoch", 10, "Max epoch before decay lr (default: 30)")
    tf.flags.DEFINE_integer('max_grad_norm', 5, 'max_grad_norm')
    
    # Training parameters
    tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
    tf.flags.DEFINE_integer("num_epochs", 60, "Number of training epochs (default: 200)")
    tf.flags.DEFINE_integer("evaluate_every", 1000, "Evaluate model on dev set after this many steps (default: 100)")
    tf.flags.DEFINE_integer("checkpoint_every", 50000, "Save model after this many steps (default: 100)")
    
    # Misc Parameters
    tf.flags.DEFINE_boolean("allow_growth", True, "Allow memory softly growth")
    tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
    tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
    
    # For distributed
    # NOTE(review): these flags are unused in this single-GPU script; they
    # mirror the distributed variant so both accept the same command line.
    tf.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
    tf.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")
    tf.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
    tf.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
    tf.flags.DEFINE_integer("issync", 0, "1 for sync and 0 for async")
    
    FLAGS = tf.flags.FLAGS
    # NOTE(review): _parse_flags() and FLAGS.__flags are TF-internal APIs
    # that were removed/renamed after TF 1.4 — pin the TF version or
    # migrate to absl.flags.
    FLAGS._parse_flags()
    print("\nParameters:")
    for attr, value in sorted(FLAGS.__flags.items()):
        print("{}={}".format(attr.upper(), value))
    print("")
    
    
    # Training
    # ==================================================
    def main(_):
        """Build the TextCNN graph and run the single-GPU training loop.

        Called by tf.app.run(); `_` receives leftover argv. Reads all
        hyper-parameters from FLAGS, trains with SGD + global-norm gradient
        clipping and a hard-coded learning-rate step schedule, evaluates on
        the validation set every `evaluate_every` steps and writes
        checkpoints every `checkpoint_every` steps.
        """
        # NOTE(review): pins the whole graph to GPU 1 (not 0) — confirm this
        # is the intended device on the training machine.
        with tf.device('/gpu:1'):
            gpu_config = tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement)
            gpu_config.gpu_options.allow_growth = FLAGS.allow_growth
            with tf.Graph().as_default(), tf.Session(config=gpu_config) as sess:
                # Load data
                print("Loading data...")
                document_length = FLAGS.document_length
                sentence_length = FLAGS.sentence_length
                train_data = data_loader.load_data(FLAGS.train_path, document_length, sentence_length, FLAGS.class_num)
                valid_data = data_loader.load_data(FLAGS.valid_path, document_length, sentence_length, FLAGS.class_num)
                # // keeps floor-division semantics under Python 2 and 3.
                batch_num_per_epoch = len(train_data[0]) // FLAGS.batch_size

                # print() calls (not Py2 print statements) so the script
                # also parses under Python 3.
                print("len train data %d" % len(train_data[0]))
                print("batch_num_per_epoch %d" % batch_num_per_epoch)

                # Build the model graph.
                cnn = TextCNN(
                    sequence_length=document_length*sentence_length,
                    num_classes=FLAGS.class_num,
                    vocab_size=FLAGS.vocab_size,
                    embedding_size=FLAGS.embedding_dim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=list(map(int, FLAGS.num_filters.split(","))),
                    l2_reg_lambda=FLAGS.l2_reg_lambda)

                # Training procedure: SGD with a feed-assigned learning rate
                # and global-norm gradient clipping.
                global_step = tf.Variable(0, name="global_step", trainable=False)
                lr = tf.Variable(0.0, trainable=False)
                new_lr = tf.placeholder(tf.float32, shape=[], name="new_learning_rate")
                _lr_update = tf.assign(lr, new_lr)

                tvars = tf.trainable_variables()
                grads, _ = tf.clip_by_global_norm(tf.gradients(cnn.loss, tvars),
                                                  FLAGS.max_grad_norm)
                optimizer = tf.train.GradientDescentOptimizer(lr)
                # BUG FIX: the original called apply_gradients() twice,
                # leaving a duplicate dead update op in the graph (and under
                # Python 3 the second call would see an exhausted zip).
                # Materialize the pairs once and create a single train_op.
                grads_and_vars = list(zip(grads, tvars))
                train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

                # Output directory for models and summaries.
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, FLAGS.ckpt_dir, timestamp))
                print("Writing to {}\n".format(out_dir))

                def train_step(sess, x_batch, y_batch, epoch_index):
                    """Run one optimization step; return (loss, iso timestamp)."""
                    # Hard-coded LR step schedule; a single if/elif chain so
                    # exactly one branch fires (the original mixed a bare
                    # second `if` into the chain).
                    if epoch_index < 15:
                        new_lr_temp = 0.1
                    elif epoch_index < 25:
                        new_lr_temp = 0.01
                    elif epoch_index < 40:
                        new_lr_temp = 0.001
                    else:
                        new_lr_temp = 0.0001

                    feed_dict = {
                      cnn.input_x: x_batch,
                      cnn.input_y: y_batch,
                      cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                      new_lr: new_lr_temp
                    }
                    # _lr_update pushes the schedule into the lr variable in
                    # the same run() call that applies the gradients.
                    _, _, loss = sess.run(
                        [_lr_update, train_op, cnn.loss],
                        feed_dict)
                    return loss, datetime.datetime.now().isoformat()


                # ====================== dev_step ======================
                def dev_step(sess, x_batch, y_batch, writer=None):
                    """Evaluate one validation batch; return (batch_len, correct).

                    Dropout is disabled (keep_prob=1.0).
                    NOTE(review): reads `epoch_index` from the enclosing
                    training loop via closure and re-assigns the lr variable
                    during evaluation — behavior preserved from the original.
                    """
                    exp = int(max(epoch_index-FLAGS.max_decay_epoch, 0)/20)
                    lr_decay = FLAGS.lr_decay ** exp

                    feed_dict = {
                      cnn.input_x: x_batch,
                      cnn.input_y: y_batch,
                      cnn.dropout_keep_prob: 1.0,
                      new_lr: FLAGS.lr*lr_decay
                    }
                    _, correct_num = sess.run(
                        [_lr_update, cnn.correct_num],
                        feed_dict)
                    return len(x_batch), correct_num


                # ====================== eval ======================
                def evaluate(sess, valid_data, batch_size):
                    """Full pass over the validation set; return accuracy."""
                    batch_iter = data_loader.batch_iter(valid_data, batch_size)
                    example_num = 0
                    correct_num = 0
                    for valid_x, valid_y in batch_iter:
                        batch_len, batch_correct = dev_step(sess, valid_x, valid_y)
                        example_num += batch_len
                        correct_num += batch_correct
                    return float(correct_num) / example_num


                # Checkpoint directory. TF assumes it already exists, so create it.
                checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
                checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                # global_variables_initializer()/global_variables() replace
                # the long-deprecated initialize_all_variables()/all_variables().
                init_op = tf.global_variables_initializer()
                saver = tf.train.Saver(tf.global_variables())
                sess.run(init_op)

                # Generate batches.
                batch_iter = data_loader.global_batch_iter(
                    train_data, FLAGS.batch_size, FLAGS.num_epochs)
                # Training loop. For each batch...
                current_step = sess.run(global_step)
                print("current step %d" % current_step)
                while current_step < batch_num_per_epoch * FLAGS.num_epochs:
                    current_step = sess.run(global_step)
                    epoch_index = current_step // batch_num_per_epoch
                    if current_step % batch_num_per_epoch == 0:
                        print("Epoch ", epoch_index)

                    x_batch, y_batch = next(batch_iter)
                    loss, time_str = train_step(sess, x_batch, y_batch, epoch_index)

                    if current_step % FLAGS.evaluate_every == 0:
                        accuracy = evaluate(sess, valid_data, FLAGS.batch_size)
                        print("{}: step {}, loss {:g}, acc {:g}".format(time_str, current_step, loss, accuracy))

                    if current_step % FLAGS.checkpoint_every == 0:
                        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                        print("Saved model checkpoint to {}\n".format(path))
    
    
    # Standard TF-1.x entry point: tf.app.run() parses flags, then calls main().
    if __name__ == "__main__":
      tf.app.run()
    

    3.并行训练过程

    import tensorflow as tf
    import numpy as np
    import os
    import time
    import datetime
    import data_loader_cnn as data_loader
    from cnn_model import TextCNN
    from tensorflow.contrib import learn
    
    # Data loading params
    # train_path here is a per-task shard *prefix*: each worker appends its
    # task_index to it at runtime (see main()).
    tf.flags.DEFINE_string("train_path", "/data/slice/", "Data source for the positive data.")
    tf.flags.DEFINE_string("valid_path", "/data/valid_data.idx", "Data source for the validation data.")
    tf.flags.DEFINE_string("ckpt_dir", "runs-cnn", "Directory for checkpoints.")
    tf.flags.DEFINE_integer("class_num", 36, "Number of total classes")
    tf.flags.DEFINE_integer("vocab_size", 500000, "Number of total distinct words")
    tf.flags.DEFINE_integer("document_length", 50, "Max number of sentences in single text")
    tf.flags.DEFINE_integer("sentence_length", 50, "Max number of words in single sentence")
    
    # Model Hyperparameters
    # NOTE(review): several "(default: ...)" help strings below are stale
    # copies from the upstream example and do not match the real defaults.
    tf.flags.DEFINE_integer("embedding_dim", 200, "Dimensionality of character embedding (default: 128)")
    tf.flags.DEFINE_string("filter_sizes", "25,50,100", "Comma-separated filter sizes (default: '3,4,5')")
    tf.flags.DEFINE_string("num_filters", "32,64,128", "Number of filters per filter size (default: 128)")
    tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
    tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")
    tf.flags.DEFINE_float("lr", 0.1, "Learning rate (default: 0.1)")
    tf.flags.DEFINE_float("lr_decay", 0.5, "Learning rate decay per epoch (default: 0.6)")
    tf.flags.DEFINE_integer("max_decay_epoch", 10, "Max epoch before decay lr (default: 30)")
    tf.flags.DEFINE_integer('max_grad_norm', 5, 'max_grad_norm')
    
    # Training parameters
    tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
    tf.flags.DEFINE_integer("num_epochs", 60, "Number of training epochs (default: 200)")
    tf.flags.DEFINE_integer("evaluate_every", 1000, "Evaluate model on dev set after this many steps (default: 100)")
    tf.flags.DEFINE_integer("checkpoint_every", 50000, "Save model after this many steps (default: 100)")
    
    # Misc Parameters
    tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
    tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
    
    # For distributed
    tf.flags.DEFINE_string("ps_hosts", "",
                           "Comma-separated list of hostname:port pairs")
    tf.flags.DEFINE_string("worker_hosts", "",
                           "Comma-separated list of hostname:port pairs")
    tf.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
    tf.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
    tf.flags.DEFINE_integer("issync", 0, "1 for sync and 0 for async")
    
    FLAGS = tf.flags.FLAGS
    # NOTE(review): _parse_flags() and FLAGS.__flags are TF-internal APIs
    # that were removed/renamed after TF 1.4 — pin the TF version or
    # migrate to absl.flags.
    FLAGS._parse_flags()
    print("\nParameters:")
    for attr, value in sorted(FLAGS.__flags.items()):
        print("{}={}".format(attr.upper(), value))
    print("")
    
    
    # Training
    # ==================================================
    def main(_):
        """Entry point for distributed (between-graph replicated) training.

        Each process is started with --job_name=ps|worker and --task_index.
        PS tasks only serve variables; each worker builds the model, loads
        its own training-data shard and runs the training loop under a
        tf.train.Supervisor. Updates are asynchronous.
        """
        ps_hosts = FLAGS.ps_hosts.split(",")
        worker_hosts = FLAGS.worker_hosts.split(",")
        cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
        server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
        # NOTE(review): issync is read but never used below — only async
        # gradient updates are actually implemented.
        issync = FLAGS.issync
        if FLAGS.job_name == "ps":
            server.join()
        elif FLAGS.job_name == "worker":
            with tf.device(tf.train.replica_device_setter(
                              worker_device="/job:worker/task:%d" % FLAGS.task_index,
                              cluster=cluster)):
                # Load this worker's shard: train_path is a prefix, one
                # slice file per task index.
                print("Loading data...")
                document_length = FLAGS.document_length
                sentence_length = FLAGS.sentence_length
                train_path = FLAGS.train_path + str(FLAGS.task_index)
                train_data = data_loader.load_data(train_path, document_length, sentence_length, FLAGS.class_num)
                valid_data = data_loader.load_data(FLAGS.valid_path, document_length, sentence_length, FLAGS.class_num)
                # // keeps floor-division semantics under Python 2 and 3.
                batch_num_per_epoch = len(train_data[0]) // FLAGS.batch_size

                # print() calls (not Py2 print statements) so the script
                # also parses under Python 3.
                print("len train data %d" % len(train_data[0]))
                print("batch_num_per_epoch %d" % batch_num_per_epoch)

                # Build the model graph.
                cnn = TextCNN(
                    sequence_length=document_length*sentence_length,
                    num_classes=FLAGS.class_num,
                    vocab_size=FLAGS.vocab_size,
                    embedding_size=FLAGS.embedding_dim,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=list(map(int, FLAGS.num_filters.split(","))),
                    l2_reg_lambda=FLAGS.l2_reg_lambda)

                # Training procedure: SGD with a feed-assigned learning rate
                # and global-norm gradient clipping.
                global_step = tf.Variable(0, name="global_step", trainable=False)
                lr = tf.Variable(0.0, trainable=False)
                new_lr = tf.placeholder(tf.float32, shape=[], name="new_learning_rate")
                _lr_update = tf.assign(lr, new_lr)

                tvars = tf.trainable_variables()
                grads, _ = tf.clip_by_global_norm(tf.gradients(cnn.loss, tvars),
                                                  FLAGS.max_grad_norm)
                optimizer = tf.train.GradientDescentOptimizer(lr)
                # BUG FIX: the original called apply_gradients() twice,
                # leaving a duplicate dead update op in the graph (and under
                # Python 3 the second call would see an exhausted zip).
                grads_and_vars = list(zip(grads, tvars))
                train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

                # Output directory for models and summaries.
                timestamp = str(int(time.time()))
                out_dir = os.path.abspath(os.path.join(os.path.curdir, FLAGS.ckpt_dir, timestamp))
                print("Writing to {}\n".format(out_dir))

                def train_step(sess, x_batch, y_batch, epoch_index):
                    """Run one optimization step; return (loss, iso timestamp)."""
                    # Hard-coded LR step schedule; a single if/elif chain so
                    # exactly one branch fires (the original mixed a bare
                    # second `if` into the chain).
                    if epoch_index < 15:
                        new_lr_temp = 0.1
                    elif epoch_index < 25:
                        new_lr_temp = 0.01
                    elif epoch_index < 40:
                        new_lr_temp = 0.001
                    else:
                        new_lr_temp = 0.0001

                    feed_dict = {
                      cnn.input_x: x_batch,
                      cnn.input_y: y_batch,
                      cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
                      new_lr: new_lr_temp
                    }
                    # _lr_update pushes the schedule into the lr variable in
                    # the same run() call that applies the gradients.
                    _, _, loss = sess.run(
                        [_lr_update, train_op, cnn.loss],
                        feed_dict)
                    return loss, datetime.datetime.now().isoformat()


                # ====================== dev_step ======================
                def dev_step(sess, x_batch, y_batch, writer=None):
                    """Evaluate one validation batch; return (batch_len, correct).

                    Dropout is disabled (keep_prob=1.0).
                    NOTE(review): reads `epoch_index` from the enclosing
                    training loop via closure and re-assigns the lr variable
                    during evaluation — behavior preserved from the original.
                    """
                    exp = int(max(epoch_index-FLAGS.max_decay_epoch, 0)/20)
                    lr_decay = FLAGS.lr_decay ** exp

                    feed_dict = {
                      cnn.input_x: x_batch,
                      cnn.input_y: y_batch,
                      cnn.dropout_keep_prob: 1.0,
                      new_lr: FLAGS.lr*lr_decay
                    }
                    _, correct_num = sess.run(
                        [_lr_update, cnn.correct_num],
                        feed_dict)
                    return len(x_batch), correct_num


                # ====================== eval ======================
                def evaluate(sess, valid_data, batch_size):
                    """Full pass over the validation set; return accuracy."""
                    batch_iter = data_loader.batch_iter(valid_data, batch_size)
                    example_num = 0
                    correct_num = 0
                    for valid_x, valid_y in batch_iter:
                        batch_len, batch_correct = dev_step(sess, valid_x, valid_y)
                        example_num += batch_len
                        correct_num += batch_correct
                    return float(correct_num) / example_num


                # Checkpoint directory. TF assumes it already exists, so create it.
                checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
                checkpoint_prefix = os.path.join(checkpoint_dir, "model")
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                # global_variables_initializer()/global_variables() replace
                # the long-deprecated initialize_all_variables()/all_variables().
                init_op = tf.global_variables_initializer()
                saver = tf.train.Saver(tf.global_variables())

                ################################################################
                # The Supervisor owns session creation/recovery and already
                # checkpoints every save_model_secs on the chief.
                sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                        logdir=checkpoint_prefix,
                                        init_op=init_op,
                                        summary_op=None,
                                        saver=saver,
                                        global_step=global_step,
                                        save_model_secs=60)

                with sv.prepare_or_wait_for_session(server.target) as sess:
                    # Generate batches.
                    batch_iter = data_loader.global_batch_iter(
                        train_data, FLAGS.batch_size, FLAGS.num_epochs)
                    # Training loop. For each batch...
                    current_step = sess.run(global_step)
                    while current_step < batch_num_per_epoch * FLAGS.num_epochs:
                        current_step = sess.run(global_step)
                        epoch_index = current_step // batch_num_per_epoch
                        if current_step % batch_num_per_epoch == 0:
                            print("Epoch ", epoch_index)

                        x_batch, y_batch = next(batch_iter)
                        loss, time_str = train_step(sess, x_batch, y_batch, epoch_index)

                        if current_step % FLAGS.evaluate_every == 0:
                            accuracy = evaluate(sess, valid_data, FLAGS.batch_size)
                            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, current_step, loss, accuracy))

                        # BUG FIX: only the chief writes explicit checkpoints;
                        # in the original every worker called saver.save,
                        # racing with the Supervisor's own checkpointing.
                        if FLAGS.task_index == 0 and current_step % FLAGS.checkpoint_every == 0:
                            path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                            print("Saved model checkpoint to {}\n".format(path))
                sv.stop()
    
    
    # Standard TF-1.x entry point: tf.app.run() parses flags, then calls main().
    if __name__ == "__main__":
      tf.app.run()
    

0 个答案:

没有答案