使用以下 API 保存 TF 模型:
tf.train.Saver
创建的检查点文件:
checkpoint
model.ckpt-297190.index
model.ckpt-297190.meta
model.ckpt-297190.data-00000-of-00001
我能够成功保存并恢复计算图(train.py 和 test.py),并且在使用 tf.train.Saver API 恢复模型后可以生成预测。
现有代码:
train.py
# Excerpt of train.py. NOTE(review): sess, reversed_dict, article_max_len,
# summary_max_len and args are defined elsewhere in train.py (not shown).
# Build the training graph first: tf.global_variables() below only returns
# variables that already exist in the graph when it is called.
model = Model(reversed_dict, article_max_len, summary_max_len, args)
# Run the initializer op for every global variable in the graph.
sess.run(tf.global_variables_initializer())
# Saver over all global variables; it is what produces the model.ckpt-*
# checkpoint files listed above.
saver = tf.train.Saver(tf.global_variables())
test.py
# Excerpt of test.py: rebuild the inference graph and restore the latest
# checkpoint from ./saved_model/ into a fresh session.
with tf.Session() as sess:
    print('Loading saved model...')
    # forward_only=True builds the graph used for prediction.
    model = Model(
        reversed_dict, article_max_len, summary_max_len, args, forward_only=True)
    # The Saver must be created after the model so it covers its variables.
    saver = tf.train.Saver(tf.global_variables())
    checkpoint = tf.train.get_checkpoint_state('./saved_model/')
    # get_checkpoint_state returns None when the directory holds no valid
    # 'checkpoint' file; fail with a clear message instead of an
    # AttributeError on None.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise FileNotFoundError('No checkpoint found in ./saved_model/')
    saver.restore(sess, checkpoint.model_checkpoint_path)
问题:
我想基于现有的检查点文件创建一个 Flask API 来提供服务。但目前每次生成预测时都会重新加载模型和计算图,因此响应很慢。
如何在主程序中只导入(加载)一次模型,然后提供一个函数(summary)来生成预测,并让它每次都复用同一个会话?
summarizer.py
import tensorflow as tf
import pickle
from model import Model
from utils import build_dict, build_dataset, batch_iter
# Module-level setup runs exactly once, at import time. A long-lived process
# (e.g. the Flask app doing `import summarizer`) therefore pays the
# dictionary/graph/checkpoint loading cost once and then reuses the same
# session for every call to summary().

with open('args.pickle', 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; args.pickle must
    # be a trusted local file.
    args = pickle.load(f)

print('Loading dictionary...')
word_dict, reversed_dict, article_max_len, summary_max_len = build_dict('test',
                                                                        args.toy)

print('Loading saved model...')
# A plain session (not `with tf.Session() as sess:`): the context manager
# closes the session when its block ends, which would make it unusable for
# later summary() calls.
sess = tf.Session()
model = Model(reversed_dict, article_max_len, summary_max_len, args,
              forward_only=True)
saver = tf.train.Saver(tf.global_variables())
checkpoint = tf.train.get_checkpoint_state('./saved_model/')
# get_checkpoint_state returns None when there is no valid checkpoint.
if checkpoint is None or not checkpoint.model_checkpoint_path:
    raise FileNotFoundError('No checkpoint found in ./saved_model/')
saver.restore(sess, checkpoint.model_checkpoint_path)


def _input_lengths(batch_x):
    """Return, for each row of batch_x, how many ids are non-zero (non-padding)."""
    return [sum(1 for y in x if y != 0) for x in batch_x]


def _clean_summary(words):
    """Join decoded words into a summary string.

    Stops at the first '</s>' and keeps only the first occurrence of each
    word, preserving order (same post-processing as the original script).
    """
    kept = []
    seen = set()
    for word in words:
        if word == '</s>':
            break
        if word not in seen:
            seen.add(word)
            kept.append(word)
    return ' '.join(kept)


def _predict(batch_x):
    """Run the restored model on a batch of id sequences.

    Returns one summary string per input row.
    """
    feed_dict = {
        model.batch_size: len(batch_x),
        model.X: batch_x,
        model.X_len: _input_lengths(batch_x),
    }
    prediction = sess.run(model.prediction, feed_dict=feed_dict)
    # Index 0 along axis 1 is used for every row, as in the original script
    # (presumably the best beam -- TODO confirm against Model).
    return [_clean_summary([reversed_dict[y] for y in row])
            for row in prediction[:, 0, :]]


def _encode(text):
    """Convert raw text into an id sequence padded/truncated to article_max_len.

    NOTE(review): this uses a plain whitespace split, maps unknown words to
    word_dict['<unk>'] (falling back to 0), and pads with 0, which is treated
    as padding by _input_lengths. It must match the preprocessing that
    build_dataset applies to test articles -- confirm against utils.py.
    """
    unk_id = word_dict.get('<unk>', 0)
    ids = [word_dict.get(word, unk_id) for word in text.split()][:article_max_len]
    return ids + [0] * (article_max_len - len(ids))


def summary(text):
    """Return a summary of text using the session loaded at import time.

    Raises ValueError if text is not a non-empty string.
    """
    if not isinstance(text, str) or not text.strip():
        raise ValueError('text must be a non-empty string')
    return _predict([_encode(text)])[0]


def main():
    """Summarize the 'test' dataset in batches, appending results to result.txt."""
    print('Loading validation dataset...')
    valid_x = build_dataset('test', word_dict, article_max_len, summary_max_len,
                            args.toy)
    batches = batch_iter(valid_x, [0] * len(valid_x), args.batch_size, 1)
    print('Writing summaries to result.txt...')
    # Opened once, in append mode as before, instead of once per batch.
    with open('result.txt', 'a') as f:
        for batch_x, _ in batches:
            for line in _predict(batch_x):
                print(line, file=f)
    print('Summaries are saved to result.txt...')


# Only run the batch job when executed as a script, so that importing this
# module (e.g. from the Flask app) just loads the model.
if __name__ == '__main__':
    main()
Flask 代码:
import summarizer
from flask import Flask, request, Response, json
app = Flask(__name__)


@app.route('/')
def index():
    """Landing endpoint: return a plain-text service description."""
    return Response('TensorFlow text summarizer')


def _json_response(payload, status):
    """Serialize payload to JSON and wrap it in a response with the given status."""
    return app.response_class(
        response=json.dumps(payload),
        status=status,
        mimetype='application/json'
    )


@app.route('/summary', methods=['POST'])
def process_text():
    """Summarize the 'text' field of a JSON POST body.

    Returns 200 with the JSON-encoded summary. Returns 400 when the body is
    not a JSON object with a 'text' field, or when the summarizer rejects the
    text with ValueError. Returns 500 for any other failure.
    """
    # silent=True returns None for a non-JSON or malformed body instead of
    # raising, so the bad request can be answered with a 400 below.
    content = request.get_json(silent=True)
    if not isinstance(content, dict) or content.get('text') is None:
        return _json_response(
            {'error': "expected a JSON object with a 'text' field"}, 400)
    try:
        summary = summarizer.summary(content['text'])
    except ValueError as e:
        return _json_response({'error': str(e)}, 400)
    except Exception:
        # Top-level request boundary: log the full traceback and send a JSON
        # 500. Returning the exception object itself is not a valid Flask
        # response, and '%e' formatting of an exception raises TypeError.
        app.logger.exception('POST /summary error')
        return _json_response({'error': 'internal error'}, 500)
    return _json_response(summary, 200)