How to run a session in TensorFlow using a saved model to generate predictions

Asked: 2018-09-23 18:07:52

Tags: python tensorflow machine-learning

  • Version: TensorFlow 1.8
  • Model: Seq2Seq model

The TF model is saved using:

  tf.train.Saver

Checkpoint files created:

  checkpoint
  model.ckpt-297190.index
  model.ckpt-297190.meta
  model.ckpt-297190.data-00000-of-00001

I am able to save and restore the graph successfully (train.py and test.py), and once the model is restored I can generate predictions using the tf.train.Saver API.

Existing code:

train.py

  model = Model(reversed_dict, article_max_len, summary_max_len, args)
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(tf.global_variables())

test.py

with tf.Session() as sess:
  print('Loading saved model...')
  model = Model(
      reversed_dict, article_max_len, summary_max_len, args, forward_only=True)
  saver = tf.train.Saver(tf.global_variables())
  checkpoint = tf.train.get_checkpoint_state('./saved_model/')
  saver.restore(sess, checkpoint.model_checkpoint_path)

Problem:

I want to create a Flask API that serves predictions from the existing checkpoint files. The response is slow because the graph is rebuilt and the checkpoint reloaded on every request.

How can I import the model once in the main code and then have a function (summary) that generates predictions while reusing that same session?

summarizer.py

import tensorflow as tf
import pickle

from model import Model
from utils import build_dict, build_dataset, batch_iter

with open('args.pickle', 'rb') as f:
    args = pickle.load(f)

print('Loading dictionary...')
word_dict, reversed_dict, article_max_len, summary_max_len = build_dict('test',
                                                                        args.toy)
print('Loading validation dataset...')
valid_x = build_dataset('test', word_dict, article_max_len, summary_max_len,
                        args.toy)
valid_x_len = [len([y for y in x if y != 0]) for x in valid_x]

with tf.Session() as sess:
    print('Loading saved model...')
    model = Model(reversed_dict, article_max_len, summary_max_len, args,
                  forward_only=True)
    saver = tf.train.Saver(tf.global_variables())
    checkpoint = tf.train.get_checkpoint_state('./saved_model/')
    saver.restore(sess, checkpoint.model_checkpoint_path)

    batches = batch_iter(valid_x, [0] * len(valid_x), args.batch_size, 1)

    print('Writing summaries to result.txt...')
    for batch_x, _ in batches:
        batch_x_len = [len([y for y in x if y != 0]) for x in batch_x]
        valid_feed_dict = {
            model.batch_size: len(batch_x),
            model.X: batch_x,
            model.X_len: batch_x_len,
        }
        prediction = sess.run(model.prediction, feed_dict=valid_feed_dict)
        prediction_output = [[reversed_dict[y] for y in x] for x in prediction[:, 0, :]]

        with open('result.txt', 'a') as f:
            for line in prediction_output:
                summary = []
                for word in line:
                    if word == '</s>':
                        break
                    if word not in summary:
                        summary.append(word)
                print(' '.join(summary), file=f)

    print('Summaries are saved to result.txt...')

Flask code:

import summarizer

from flask import Flask, request, Response, json

app = Flask(__name__)


@app.route('/')
def index():
    return Response('TensorFlow text summarizer')


@app.route('/summary', methods=['POST'])
def process_text():
    """Process text."""
    try:
        print(request.is_json)
        content = request.get_json()
        text = content.get('text')
        summary = summarizer.summary(text)
        return app.response_class(
            response=json.dumps(summary),
            status=200,
            mimetype='application/json'
        )

    except Exception as e:
        # %e is a float format specifier; use %s to print the exception
        print('POST /summary error: %s' % e)
        return Response(str(e), status=500)

0 Answers:

There are no answers yet