Tensorflow: my model does not register after the first iteration

Asked: 2018-08-30 08:21:46

Tags: python tensorflow

I followed a tutorial available here and tried to run it on my own dataset. It compiles and starts training, but this is what I get:

[screenshot of the training output]

The model does not seem to be saved at each iteration. I tried 100 epochs, but nothing changed; it keeps giving the output of the first iteration. Do you know what the problem might be? (I know the code is messy, sorry.)

import numpy as np
import tensorflow as tf
from tqdm import tqdm


def train(model, epochs, log_string):
    '''Train the RNN'''
    # x_train, y_train, x_valid, y_valid, batch_size, dropout and get_batches
    # are globals defined earlier in the script (as in the tutorial)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Used to determine when to stop the training early
        valid_loss_summary = []
        stop_early = 0

        # Keep track of which batch iteration is being trained
        iteration = 0

        print()
        print("Training Model: {}".format(log_string))

        train_writer = tf.summary.FileWriter('./logs/3/train/{}'.format(log_string), sess.graph)
        valid_writer = tf.summary.FileWriter('./logs/3/valid/{}'.format(log_string))

        for e in range(epochs):
            state = sess.run(model.initial_state)

            # Record progress with each epoch
            train_loss = []
            train_acc = []
            val_acc = []
            val_loss = []

            with tqdm(total=len(x_train)) as pbar:
                for x, y in get_batches(x_train, y_train, batch_size):
                    feed = {model.inputs: x,
                            model.labels: y[:, None],
                            model.keep_prob: dropout,
                            model.initial_state: state}
                    summary, loss, acc, state, _ = sess.run([model.merged, 
                                                             model.cost, 
                                                             model.accuracy, 
                                                             model.final_state, 
                                                             model.optimizer], 
                                                            feed_dict=feed)                

                    # Record the loss and accuracy of each training batch
                    train_loss.append(loss)
                    train_acc.append(acc)

                    # Record the progress of training
                    train_writer.add_summary(summary, iteration)

                    iteration += 1
                    pbar.update(batch_size)

            avg_train_loss = np.mean(train_loss)
            avg_train_acc = np.mean(train_acc) 

            val_state = sess.run(model.initial_state)
            with tqdm(total=len(x_valid)) as pbar:
                for x, y in get_batches(x_valid, y_valid, batch_size):
                    feed = {model.inputs: x,
                            model.labels: y[:, None],
                            model.keep_prob: 1,
                            model.initial_state: val_state}
                    summary, batch_loss, batch_acc, val_state = sess.run([model.merged, 
                                                                          model.cost, 
                                                                          model.accuracy, 
                                                                          model.final_state], 
                                                                         feed_dict=feed)

                    # Record the validation loss and accuracy of each epoch
                    val_loss.append(batch_loss)
                    val_acc.append(batch_acc)
                    pbar.update(batch_size)

            # Average the validation loss and accuracy of each epoch
            avg_valid_loss = np.mean(val_loss)    
            avg_valid_acc = np.mean(val_acc)
            valid_loss_summary.append(avg_valid_loss)

            # Record the validation data's progress
            valid_writer.add_summary(summary, iteration)

            # Print the progress of each epoch

            print("Epoch: {}/{}".format(e, epochs),
                   "Train Loss: {:.3f}".format(avg_train_loss),
                   "Train Acc: {:.3f}".format(avg_train_acc),
                   "Valid Loss: {:.3f}".format(avg_valid_loss),
                   "Valid Acc: {:.3f}".format(avg_valid_acc))

            # Stop training if the validation loss does not decrease after 3 epochs
            if avg_valid_loss > min(valid_loss_summary):
                print("No Improvement.")
                stop_early += 1
                if stop_early == 3:
                    break

            # Reset stop_early if the validation loss finds a new low
            # Save a checkpoint of the model
            else:
                print("New Record!")
                stop_early = 0
                checkpoint = "sauvegarde/controverse_{}.ckpt".format(log_string)
                saver.save(sess, checkpoint)
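
For reference, get_batches is the small batching helper from the tutorial; roughly (my version may differ slightly), it looks like this:

def get_batches(x, y, batch_size):
    '''Yield mini-batches of size batch_size; leftover samples that do not fill a full batch are dropped'''
    n_batches = len(x) // batch_size
    x, y = x[:n_batches * batch_size], y[:n_batches * batch_size]
    for i in range(0, len(x), batch_size):
        yield x[i:i + batch_size], y[i:i + batch_size]

I call train once after building the graph, e.g. train(model, epochs, log_string) with epochs = 100.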

Thank you very much for your answers :)

0 Answers:

No answers yet