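I trained a TensorFlow model to classify pieces of text. Classification needs to be fast, because the results are consumed by another algorithm. My solution is to create a simple API that serves the model and classifies text whenever that algorithm asks for it. The problem is that, unlike other models such as gensim models, I cannot figure out how to load the model once and then ask it for predictions. The only way I have gotten it to work is to load the model on every API call, but that is very time-consuming.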
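To make the goal concrete, here is a minimal sketch of the load-once structure I am after. It uses only the Python standard library. `ModelHolder`, `fake_loader` and `load_calls` are hypothetical names for illustration, and the loader is a stand-in for the TensorFlow code that would build the graph and restore the checkpoint.

```python
import threading


class ModelHolder:
    """Loads a model once and reuses it for every prediction request."""

    def __init__(self, load_fn):
        self._load_fn = load_fn      # expensive loader, called at most once
        self._model = None
        self._lock = threading.Lock()

    def get(self):
        # Double-checked locking so concurrent requests trigger only one load.
        if self._model is None:
            with self._lock:
                if self._model is None:
                    self._model = self._load_fn()
        return self._model

    def predict(self, texts):
        return self.get()(texts)


load_calls = []


def fake_loader():
    """Hypothetical stand-in for building the graph and restoring the checkpoint."""
    load_calls.append(1)
    # The "model" here is just a function that labels texts by length.
    return lambda texts: ["long" if len(t) > 10 else "short" for t in texts]


holder = ModelHolder(fake_loader)
first = holder.predict(["hello", "a much longer sentence"])
second = holder.predict(["hi"])
print(first, second, len(load_calls))
```

As far as I understand, in the real API the loader would have to return something that keeps the `tf.Graph` and the `tf.Session` it created together, so that every request reuses that same pair instead of building a new graph.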
This is the code I use for prediction:
params, words_index, labels, embedding_mat = load_trained_params(trained_dir)
data = pd.DataFrame(data=example)
x = data[0].apply(lambda x: data_helper.clean_str(x).split(' ')).tolist()
x_ = data_helper.pad_sentences(x, forced_sequence_length=params['sequence_length'])
x_ = map_word_to_index(x_, words_index)
x_test, y_test = np.asarray(x_), None

with tf.Graph().as_default():
    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        cnn_rnn = TextCNNRNN(
            embedding_mat=embedding_mat,
            non_static=params['non_static'],
            hidden_unit=params['hidden_unit'],
            sequence_length=len(x_test[0]),
            max_pool_size=params['max_pool_size'],
            filter_sizes=map(int, params['filter_sizes'].split(",")),
            num_filters=params['num_filters'],
            num_classes=len(labels),
            embedding_size=params['embedding_dim'],
            l2_reg_lambda=params['l2_reg_lambda'])

        def real_len(batches):
            return [np.ceil(np.argmin(batch + [0]) * 1.0 / params['max_pool_size']) for batch in batches]

        def predict_step(x_batch):
            feed_dict = {
                cnn_rnn.input_x: x_batch,
                cnn_rnn.dropout_keep_prob: 1.0,
                cnn_rnn.batch_size: len(x_batch),
                cnn_rnn.pad: np.zeros([len(x_batch), 1, params['embedding_dim'], 1]),
                cnn_rnn.real_len: real_len(x_batch),
            }
            predictions = sess.run([cnn_rnn.predictions], feed_dict)
            return predictions

        checkpoint_file = trained_dir + 'best_model'
        saver = tf.train.Saver(tf.global_variables())
        print("{}.meta".format(checkpoint_file))
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        print(saver)
        saver.restore(sess, checkpoint_file)
        logging.critical('{} has been loaded'.format(checkpoint_file))

        batches = data_helper.batch_iter(list(x_test), params['batch_size'], 1, shuffle=False)

        predictions, predict_labels = [], []
        for x_batch in batches:
            batch_predictions = predict_step(x_batch)[0]
            for batch_prediction in batch_predictions:
                predictions.append(batch_prediction)
                predict_labels.append(labels[batch_prediction])
        print(predict_labels)
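I also tried loading the model and the graph once, when the API starts. The problem is that when I then try to use the graph, I get an error saying that the variable I am using is not part of the graph. The framework I am using is Flask, but it could be any other. I am new to this, so I may be missing something simple. Thanks for your help.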