API for a TensorFlow classifier

Date: 2017-07-04 09:55:36

Tags: python tensorflow

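I trained a TensorFlow model to classify pieces of text. Classification has to be fast, because the results feed into another algorithm. My plan is to write a simple API that serves the model and classifies text whenever the algorithm asks for it. The problem is that, unlike other models such as gensim models, I can't figure out how to load the model once and then query it for predictions. The only way I can get it to work is to load the model on every API call, and that is very slow.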
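The load-once idea the question is after can be sketched without TensorFlow. In this minimal sketch, load_model, LazyModel, classify and the path are hypothetical stand-ins: the expensive load runs only on the first request, later calls reuse the loaded object, and a lock keeps several simultaneous first requests from loading it twice.

```python
import threading
import time


def load_model(path):
    """Hypothetical stand-in for an expensive load (e.g. restoring a checkpoint)."""
    time.sleep(0.1)  # simulate slow loading
    return {"path": path, "labels": ["negative", "positive"]}


class LazyModel:
    """Loads the model on first use, then reuses it for every later call."""

    def __init__(self, path):
        self._path = path
        self._model = None
        self._lock = threading.Lock()
        self.load_count = 0  # for demonstration: how many real loads happened

    def get(self):
        # Double-checked locking: concurrent first calls trigger only one load
        if self._model is None:
            with self._lock:
                if self._model is None:
                    self._model = load_model(self._path)
                    self.load_count += 1
        return self._model


MODEL = LazyModel("path/to/trained_dir/")  # hypothetical path


def classify(text):
    model = MODEL.get()  # cheap after the first call
    # Hypothetical toy rule standing in for the network's prediction
    return model["labels"][1] if "good" in text else model["labels"][0]


results = [classify(t) for t in ["good movie", "bad movie", "good plot"]]
print(results, MODEL.load_count)  # ['positive', 'negative', 'positive'] 1
```

Loading eagerly when the application starts, before the server accepts requests, avoids the delay on the first request; lazy loading as above is the alternative when fast startup matters more.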

This is the code I use for prediction:

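import logging

import numpy as np
import pandas as pd
import tensorflow as tf

import data_helper  # the poster's own helper module (clean_str, pad_sentences, batch_iter)

# TextCNNRNN, load_trained_params and map_word_to_index are defined elsewhere in the
# poster's project (not shown); trained_dir is the checkpoint directory and example
# is the list of texts to classify.

# Load the saved hyperparameters, vocabulary, labels and embedding matrix
params, words_index, labels, embedding_mat = load_trained_params(trained_dir)
data = pd.DataFrame(data=example)

# Clean and tokenize each text
x = data[0].apply(lambda x: data_helper.clean_str(x).split(' ')).tolist()

# Pad to the training sequence length, then map words to vocabulary indices
x_ = data_helper.pad_sentences(x, forced_sequence_length=params['sequence_length'])
x_ = map_word_to_index(x_, words_index)

x_test, y_test = np.asarray(x_), None

with tf.Graph().as_default():
    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Rebuild the network inside this graph
        cnn_rnn = TextCNNRNN(
            embedding_mat=embedding_mat,
            non_static=params['non_static'],
            hidden_unit=params['hidden_unit'],
            sequence_length=len(x_test[0]),
            max_pool_size=params['max_pool_size'],
            # list() so this is not a one-shot iterator under Python 3
            filter_sizes=list(map(int, params['filter_sizes'].split(","))),
            num_filters=params['num_filters'],
            num_classes=len(labels),
            embedding_size=params['embedding_dim'],
            l2_reg_lambda=params['l2_reg_lambda'])

        def real_len(batches):
            # Lengths fed to cnn_rnn.real_len, scaled down by max_pool_size
            return [np.ceil(np.argmin(batch + [0]) * 1.0 / params['max_pool_size']) for batch in batches]

        def predict_step(x_batch):
            # One forward pass with dropout disabled (keep probability 1.0)
            feed_dict = {
                cnn_rnn.input_x: x_batch,
                cnn_rnn.dropout_keep_prob: 1.0,
                cnn_rnn.batch_size: len(x_batch),
                cnn_rnn.pad: np.zeros([len(x_batch), 1, params['embedding_dim'], 1]),
                cnn_rnn.real_len: real_len(x_batch),
            }
            predictions = sess.run([cnn_rnn.predictions], feed_dict)
            return predictions

        checkpoint_file = trained_dir + 'best_model'
        saver = tf.train.Saver(tf.global_variables())
        print("{}.meta".format(checkpoint_file))
        # import_meta_graph also imports the saved graph's ops into the current graph
        # (on top of the model built above) and returns a new Saver that replaces
        # the one created on the previous line
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        print(saver)
        # Restore the trained weights into this session
        saver.restore(sess, checkpoint_file)
        logging.critical('{} has been loaded'.format(checkpoint_file))

        # Predict batch by batch, without shuffling so the output order matches the input
        batches = data_helper.batch_iter(list(x_test), params['batch_size'], 1, shuffle=False)

        predictions, predict_labels = [], []
        for x_batch in batches:
            batch_predictions = predict_step(x_batch)[0]
            for batch_prediction in batch_predictions:
                predictions.append(batch_prediction)
                # Map the predicted class index back to its label name
                predict_labels.append(labels[batch_prediction])

        print(predict_labels)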

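When the API starts, I try to load the model and the graph. The problem is that when I later try to use the graph, I get an error saying that the variables I'm using are not part of the graph. I'm using Flask, but the framework could be any other. I'm new to this, so I may be missing something simple. Thanks for your help.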
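One possible cause of this error: in TF 1.x the default graph is a property of the current thread, so a request handled on a Flask worker thread does not see the graph that was the default while the model was being loaded, and any op created or looked up there ends up in a different graph. A common remedy is to keep explicit references to the tf.Graph and tf.Session created at startup, build every op once, and only call sess.run on that session from the handler. This is a minimal sketch under those assumptions, written against tensorflow.compat.v1 (on TF 1.x before compat.v1 existed, use import tensorflow as tf); TextClassifierService and the single-layer stand-in for TextCNNRNN are hypothetical.

```python
import threading

import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x-style graph/session API


class TextClassifierService:
    """Hypothetical wrapper: builds the graph and session once, then serves predictions."""

    def __init__(self):
        self.graph = tf.Graph()
        with self.graph.as_default():
            # Hypothetical stand-in for TextCNNRNN: one linear layer with fixed weights
            self.input_x = tf.placeholder(tf.float32, shape=[None, 3], name="input_x")
            weights = tf.Variable([[1.0], [2.0], [3.0]], name="weights")
            self.predictions = tf.squeeze(tf.matmul(self.input_x, weights), axis=1)
            # In the real service, create the tf.train.Saver here, inside the graph
            init = tf.global_variables_initializer()
        # Make the graph read-only: adding an op to it later raises an error
        # instead of silently growing it on every request
        self.graph.finalize()
        self.sess = tf.Session(graph=self.graph)
        # In the real service, saver.restore(self.sess, checkpoint_file) replaces this
        self.sess.run(init)

    def predict(self, batch):
        # Only runs existing tensors through the session that owns the graph,
        # so it does not depend on the calling thread's default graph
        return self.sess.run(self.predictions, feed_dict={self.input_x: batch})


service = TextClassifierService()  # load once, at startup
results = {}


def handle_request(request_id, batch):
    # Simulates a request served on a worker thread, as a Flask server may do
    results[request_id] = service.predict(batch)


threads = [
    threading.Thread(target=handle_request, args=(i, [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]))
    for i in range(3)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results[0])
```

If a handler does need to create ops or fetch tensors by name, wrapping that code in with service.graph.as_default(): inside the handler is the usual way to point it at the right graph.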
