Can I continue training from a previous session?

Time: 2019-10-21 20:03:58

Tags: python tensorflow

I am writing code for text summarization. I want to continue training from a previous session, using the best model saved during that session. I am running the notebook on Google Colaboratory with TensorFlow 1.15.0.

    learning_rate_decay = 0.95
    min_learning_rate = 0.0005
    display_step = 20 # Check training loss after every 20 batches
    stop_early = 0 
    stop = 3 # If the update loss does not decrease in 3 consecutive update checks, stop training
    per_epoch = 3 # Make 3 update checks per epoch
    update_check = (len(sorted_texts)//batch_size//per_epoch)-1

    update_loss = 0 
    batch_loss = 0
    summary_update_loss = [] # Record the update losses for saving improvements in the model

    checkpoint = "./best_model.ckpt" 
    with tf.Session(graph=train_graph) as sess:
        sess.run(tf.global_variables_initializer())

        # If we want to continue training a previous session
        loader = tf.train.import_meta_graph(checkpoint + '.meta')
        loader.restore(sess, checkpoint)

        for epoch_i in range(1, epochs+1):
            update_loss = 0
            batch_loss = 0
            for batch_i, (summaries_batch, texts_batch, summaries_lengths, texts_lengths) in enumerate(
                    get_batches(sorted_summaries, sorted_texts, batch_size)):
                start_time = time.time()
                _, loss = sess.run(
                    [train_op, cost],
                    {input_data: texts_batch,
                     targets: summaries_batch,
                     lr: learning_rate,
                     summary_length: summaries_lengths,
                     text_length: texts_lengths,
                     keep_prob: keep_probability})

                batch_loss += loss
                update_loss += loss
                end_time = time.time()
                batch_time = end_time - start_time

                if batch_i % display_step == 0 and batch_i > 0:
                    print('Epoch {:>3}/{} Batch {:>4}/{} - Loss: {:>6.3f}, Seconds: {:>4.2f}'
                          .format(epoch_i,
                                  epochs, 
                                  batch_i, 
                                  len(sorted_texts) // batch_size, 
                                  batch_loss / display_step, 
                                  batch_time*display_step))
                    batch_loss = 0

                if batch_i % update_check == 0 and batch_i > 0:
                    print("Average loss for this update:", round(update_loss/update_check,3))
                    summary_update_loss.append(update_loss)

                    # If the update loss is at a new minimum, save the model
                    if update_loss <= min(summary_update_loss):
                        print('New Record!') 
                        stop_early = 0
                        saver = tf.train.Saver() 
                        saver.save(sess, checkpoint)

                    else:
                        print("No Improvement.")
                        stop_early += 1
                        if stop_early == stop:
                            break
                    update_loss = 0


            # Reduce learning rate, but not below its minimum value
            learning_rate *= learning_rate_decay
            if learning_rate < min_learning_rate:
                learning_rate = min_learning_rate

            if stop_early == stop:
                print("Stopping Training.")
                break

When I execute this code, I get the following error in the output:

INFO:tensorflow:Restoring parameters from ./best_model.ckpt
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1119             subfeed_t = self.graph.as_graph_element(
-> 1120                 subfeed, allow_tensor=True, allow_operation=False)
   1121           except Exception as e:

4 frames
ValueError: Tensor Tensor("input:0", shape=(?, ?), dtype=int32) is not an element of this graph.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1121           except Exception as e:
   1122             raise TypeError('Cannot interpret feed_dict key as Tensor: ' +
-> 1123                             e.args[0])
   1124 
   1125           if isinstance(subfeed_val, ops.Tensor):

TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input:0", shape=(?, ?), dtype=int32) is not an element of this graph.
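
For reference, this ValueError/TypeError pair usually means that the feed_dict keys (input_data, targets, lr, summary_length, text_length, keep_prob) are Python variables that still point at tensors belonging to a graph other than the one the session is running, for example after a runtime restart or when import_meta_graph builds a second copy of the graph. Below is a minimal sketch of restoring the checkpoint into a fresh default graph and re-fetching the tensor handles by name; the tensor names used (such as 'input:0') are assumptions based on the error message, not names confirmed by the question.

    import tensorflow as tf

    checkpoint = "./best_model.ckpt"

    # Start from a clean default graph so the restored tensors are the only copies.
    tf.reset_default_graph()

    with tf.Session() as sess:
        # Rebuild the graph structure from the .meta file and load the saved weights.
        loader = tf.train.import_meta_graph(checkpoint + '.meta')
        loader.restore(sess, checkpoint)

        # Re-fetch handles by name from the restored graph instead of reusing
        # Python variables that point into the old graph.
        graph = tf.get_default_graph()
        input_data = graph.get_tensor_by_name('input:0')  # assumed tensor name
        # targets, lr, summary_length, text_length, keep_prob, train_op and cost
        # would be looked up the same way, using whatever names were given when
        # the graph was originally built.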

0 Answers:

No answers yet.