如何从张量流中的检查点还原复杂对象以通过还原检查点对新数据进行预测

时间:2019-06-26 07:22:38

标签: python tensorflow machine-learning

我正在研究Tensorflow模型。在这里,我创建了一个自定义类,用于构建模型。下面是培训代码。

    sess = tf.Session(config=session_conf)
    with sess.as_default():
        model = EntityAttentionLSTM(
            sequence_length=train_x.shape[1],
            num_classes=train_y.shape[1],
            vocab_size=len(vocab_processor.vocabulary_),
            embedding_size=FLAGS.embedding_size,
            pos_vocab_size=len(pos_vocab_processor.vocabulary_),
            pos_embedding_size=FLAGS.pos_embedding_size,
            hidden_size=FLAGS.hidden_size,
            num_heads=FLAGS.num_heads,
            attention_size=FLAGS.attention_size,
            use_elmo=(FLAGS.embeddings == 'elmo'),
            l2_reg_lambda=FLAGS.l2_reg_lambda)

            for train_batch in train_batches:
                train_bx, train_by, train_btxt, train_be1, train_be2, 
                train_bp1, train_bp2 = zip(*train_batch)
                feed_dict = {
                    model.input_x: train_bx,
                    model.input_y: train_by,
                    model.input_text: train_btxt,
                    model.input_e1: train_be1,
                    model.input_e2: train_be2,
                    model.input_p1: train_bp1,
                    model.input_p2: train_bp2,
                    model.emb_dropout_keep_prob: 
                    FLAGS.emb_dropout_keep_prob,
                    model.rnn_dropout_keep_prob: 
                    FLAGS.rnn_dropout_keep_prob,
                    model.dropout_keep_prob: FLAGS.dropout_keep_prob
            }
            _, step, summaries, loss, accuracy = sess.run(
                [train_op, global_step, train_summary_op, model.loss, model.accuracy], feed_dict)
            train_summary_writer.add_summary(summaries, step)

这一切正常,但是如果我必须加载检查点并恢复模型,那么预测逻辑将如何工作?如何还原一个类的对象?请帮助

0 个答案:

没有答案