Streaming data to estimator.predict

Time: 2019-07-11 08:43:12

Tags: python tensorflow tensorflow-serving tensorflow-estimator

I have a TensorFlow model based on a pre-trained Estimator, and I am trying to use it on streaming data.

Here is the code:

import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from time import sleep

import tensorflow as tf

from bent.data.util import bert_tokenizer, convert_bert_docs, IDMapper
from bent.data.inputs import *  # expected to provide get_types, get_shapes, process_dataset, docs_segment_generator


class StreamingSamples(object):
    """Holds the most recent sample pushed by the HTTP handler and exposes
    it to the input generator via the `empty` flag."""

    def __init__(self, FLAGS):
        self.FLAGS = FLAGS
        self.sample = None
        self.empty = True
        self.tokenizer = bert_tokenizer(FLAGS.word_embedding)

    def new_sample(self, sample):
        # Tokenize the incoming text and mark a new sample as available.
        self.sample = sample
        self.token_ids, self.tokens = convert_bert_docs(docs=[self.sample], encoder=self.tokenizer)
        self.empty = False

    def gen(self):
        # Infinite generator consumed by tf.data.Dataset.from_generator:
        # when a sample is available, yield its segments one by one;
        # otherwise poll every 0.5 s until the handler pushes the next one.
        while True:
            if not self.empty:
                self.empty = True
                for doc_id, segment_id, token_seg, token_id_seg in docs_segment_generator(
                        tokens=self.tokens,
                        token_ids=self.token_ids,
                        doc_ids=range(len(self.tokens)),
                        max_seq_size=self.FLAGS.max_seq_size):
                    feature = {"doc_id": doc_id,
                               "segment_id": segment_id,
                               "token": token_seg,
                               "token_id": token_id_seg,
                               }
                    label = {}
                    yield feature, label
            else:
                print("sleep")
                sleep(0.5)


def py_input_fn(gen, batch_size, max_seq_size) -> tf.data.Dataset:
    # Build a dataset from the streaming generator; shapes are left dynamic
    # (max_seq_size=None) and padding/batching is delegated to process_dataset.
    ds = tf.data.Dataset.from_generator(
        generator=gen,
        output_types=get_types(mode=tf.estimator.ModeKeys.PREDICT, label_to_feature=False),
        output_shapes=get_shapes(
            mode=tf.estimator.ModeKeys.PREDICT,
            max_seq_size=None,
            label_to_feature=False
        )
    )

    ds = process_dataset(ds=ds,
                         mode=tf.estimator.ModeKeys.PREDICT,
                         batch_size=batch_size,
                         max_seq_size=max_seq_size,
                         epoch=1,
                         label_to_feature=False)
    return ds


class GetHandler(BaseHTTPRequestHandler):

    def do_POST(self):
        # Read the JSON body, run a prediction, and write the result back as JSON.
        content_length = int(self.headers['Content-Length'])
        post_data = self.rfile.read(content_length)
        self.send_response(200)
        self.end_headers()
        response = nnprocessing.predict(*read_json(post_data))
        print("response in server.py:\n", response)
        self.wfile.write(bytes(json.dumps(response), "utf-8"))


def read_json(post_data):
    # Decode the request body into the raw text and a list of (start, length) spans.
    data = json.loads(post_data.decode("utf-8"))
    text = data["text"]
    spans = [(int(j["start"]), int(j["length"])) for j in data["spans"]]
    return text, spans


class NNProcessing(object):
    def __init__(self, est, FLAGS):
        self.est = est
        self.FLAGS = FLAGS
        self.first_run = True
        self.streaming_samples = StreamingSamples(FLAGS)

    def predict(self, text, given_spans):
        # Push the new text into the streaming generator, then pull one
        # prediction from the long-lived predictions iterator.
        self.streaming_samples.new_sample(text)
        if self.first_run:
            # Build the predictions generator once; it keeps the estimator's
            # input pipeline alive across requests.
            self.predictions = self.est.predict(
                input_fn=lambda: py_input_fn(
                    self.streaming_samples.gen,
                    self.FLAGS.batch_size,
                    self.FLAGS.max_seq_size
                ),
                predict_keys=["doc_id",
                              "segment_id",
                              "token",
                              "token_mask",
                              "head_mask",
                              "iob_pred",
                              "iob_class_pred",
                              "iob_class_label",
                              "entity_pred",
                              "entity_score",
                              "entity_label",
                              "entity_embedding_pred"],
                yield_single_examples=True,
            )
            self.first_run = False

        r = next(self.predictions)
        print(r)
        return r
#         for row in next(self.results):
#             logging.info(f"doc_id: {row['doc_id']}, segment_id: {row['segment_id']}")
#             for i in range(FLAGS.max_seq_size):
#                 if row['token_mask'][i]:
#                     logging.info(
#                         f"{row['token'][i].decode('utf8')}\t"
#                         f"{row['iob_class_pred'][i]}\t"
#                         f"{row['entity_pred'][i]}\t"
#                     )


def start_server(est, FLAGS):
    # GetHandler reaches the model through this module-level global.
    global nnprocessing
    nnprocessing = NNProcessing(est, FLAGS)
    server = HTTPServer(('localhost', 5555), GetHandler)
    print('Starting server at http://localhost:5555')
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        exit(0)
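
For reference, read_json expects a JSON body with a "text" field and a "spans" list whose items carry "start" and "length". A request to this server might look like the following (a minimal client sketch using the requests library; the payload values are hypothetical):

import requests

# Hypothetical example payload mirroring what read_json expects.
payload = {
    "text": "Barack Obama visited Paris.",
    "spans": [{"start": 0, "length": 12}],
}
resp = requests.post("http://localhost:5555", json=payload)
print(resp.json())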

The problem is that after the first input goes in, the estimator never returns anything and just keeps printing the "sleep" message. I suspect the cause is the generator function.
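
One detail worth noting (an assumption on my part, since process_dataset is not shown here): if process_dataset batches with a fixed batch_size, a single streamed segment may never fill a batch, so next(self.predictions) blocks inside the estimator while gen keeps polling and printing "sleep". A quick way to test this hypothesis is to force a batch of one (est, streaming_samples and FLAGS mirror the names used above):

# Hypothetical diagnostic: if predictions start flowing with batch_size=1,
# the stall is the batching op waiting for a full batch, not the generator.
predictions = est.predict(
    input_fn=lambda: py_input_fn(streaming_samples.gen, batch_size=1,
                                 max_seq_size=FLAGS.max_seq_size),
    yield_single_examples=True,
)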

I need the estimator to return a result whenever a request reaches my server.

Please point me in the right direction. Can this be done without the tensorflow-serving API?
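
For context, a pattern often used for this kind of streaming prediction without tensorflow-serving is to back the generator with a blocking queue.Queue, so the input pipeline waits for work instead of polling a flag. A minimal sketch (independent of the code above; all names hypothetical):

import queue

sample_queue = queue.Queue()

def queued_gen():
    # Blocks until the HTTP handler puts a (feature, label) pair on the
    # queue, so the tf.data pipeline never busy-waits or misses a sample.
    while True:
        yield sample_queue.get()

# In the request handler: sample_queue.put((feature, {})), then read one
# result from the single long-lived predictions iterator with next().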

TIA

0 Answers:

No answers yet.