使用plotly在具有实时绘图的Web应用程序中部署Keras Model

时间:2019-05-10 20:43:59

标签: python flask plotly plotly-dash

我已经使用keras开发了一种深度学习模型,并且已成功将其部署在Flash应用程序中。中包含/ predict端点,另一个应用程序可以使用该端点查询数据,然后将预测返回给它。该flask应用程序只是作为控制台应用程序运行,但除此之外,我还想弹出一个网页,以显示输入数据的实时散点图。我能够使用bokeh库并将其结合到我的flask应用程序中来执行此操作。现在,当应用启动时,将打开该图的网页,并将其配置为定期在flask应用中命中新的“ /数据”端点,以获取最新数据并更新2d散点图。这似乎工作可靠。

现在是我发帖子的原因。我真的很想通过将2d散点图更改为3d散点图来改善这一点。不幸的是,散景不提供开箱即用的3D散点图。这使我无所适从。我知道可以进行3D散点图。我正在寻找示例代码,或有关如何执行此操作的建议。

我应该保留flask应用程序的现有结构,但以某种方式将其与plotly结合以代替散景吗?这意味着必须创建网页并定期访问我的flask应用程序的/ data点。

我应该放弃烧瓶,改为使用破折号应用吗?如果是这样,那么我将需要在破折号中部署一个keras模型。并设置/ predict和/ data值。

下面是我现有的有关烧瓶和散景的代码。有人对我应该如何做甚至更好做任何建议吗?有人可以向我展示一些适用于我要尝试的示例代码。

import numpy as np
import tensorflow as tf
import keras
from keras.models import model_from_json
from flask import Flask, jsonify, make_response, request

from bokeh.plotting import figure, show
from bokeh.models import AjaxDataSource, CustomJS, Range1d

# Bokeh related code

adapter = CustomJS(code="""
    const result = {x: [], y: []}
    const pts = cb_data.response.points
    for (i=0; i<pts.length; i++) {
        result.x.push(pts[i][0])
        result.y.push(pts[i][1])
    }
    return result
""")

source = AjaxDataSource(data_url='http://10.61.226.215:443/data',
                        polling_interval=200, adapter=adapter)

p = figure(plot_height=700, plot_width=1400, x_axis_label='Frequency', y_axis_label='Phase', background_fill_color="lightgrey",
           title="Scatter Plot of TOI")
p.x_range = Range1d(0, 49)
p.y_range = Range1d(-2**15, 2**15)
p.circle('x', 'y', source=source, color='red', size=10)

# Flask related code

app = Flask(__name__)

def get_model():
    global model
    global g

    json_path = 'RNN_LSTM_128_Drop0p2.json'
    h5_path = "RNN-010-0.983732-0.997119.h5"
    g = tf.Graph()
    with g.as_default():
        # Pull in the model we want to test
        json_file = open(json_path, 'r')
        loaded_model_json = json_file.read()
        json_file.close()
        model = model_from_json(loaded_model_json)
        # load weights into new model
        model.load_weights(h5_path)
        print("Loaded model from disk")
        # Compile the loaded model
        model.compile(loss=keras.losses.categorical_crossentropy,
                      optimizer=keras.optimizers.Adadelta(),
                      metrics=['accuracy'])
        print("Compiled Model")
    print(" * Recurrent Neural Network Trained Model Loaded Successfully")


def RNN_scale_input(X):
    N_samp = X.shape[0]
    N_t_samp = X.shape[1]
    n_feat = X.shape[2]

    X_scale = np.zeros(shape=(N_samp, N_t_samp, n_feat), dtype=np.float32)
    for h in range(N_samp):
        # Scale the channel index to be between -1 and 1
        X_scale[h][:, 0] = (X[h][:, 0] - 25) / 25
        # Scale the phase value to be between -1 and 1
        X_scale[h][:, 1] = X[h][:, 1] / 2**15
    return X_scale


# Define function to pre-process phase data
def pre_process_data(data, num_rows, num_cols):
    # Strip the opening and closing brackets from the input data string
    data = data.strip('[]')
    # Convert string into a list of floating point numbers
    try:
        data = [float(h) for h in data.split(',')]
    except:
        print("***ERROR***")
        print(data)
    data_np = np.asarray(data).reshape(1, num_rows, num_cols)
    rnn_structure = RNN_scale_input(data_np)
    return rnn_structure[0]

def pre_process_plot_data(data, num_rows, num_cols):
    # Strip the opening and closing brackets from the input data string
    data = data.strip('[]')
    # Convert string into a list of floating point numbers
    try:
        list_data = [int(k) for k in data.split(',')]
    except:
        print("***ERROR***")
        print(data)
    plot_data_np = np.asarray(list_data).reshape(num_rows, num_cols)
    return plot_data_np


N_t, n_features = 64, 2
print(" * Loading Neural Network Trained Model...")
get_model()

# Define Tag of Interest to plot
TOI = "307401320416C2054B6E99D7"

def crossdomain(f):
    def wrapped_function(*args, **kwargs):
        resp = make_response(f(*args, **kwargs))
        h = resp.headers
        h['Access-Control-Allow-Origin'] = '*'
        h['Access-Control-Allow-Methods'] = "GET, OPTIONS, POST"
        h['Access-Control-Max-Age'] = str(21600)
        requested_headers = request.headers.get('Access-Control-Request-Headers')
        if requested_headers:
            h['Access-Control-Allow-Headers'] = requested_headers
        return resp
    return wrapped_function


x = [0]*N_t
y = [0]*N_t

@app.route("/predict", methods=['POST'])
def predict():
    global x
    global y
    message = request.get_json(force=True)
    TagId = message['TagId']
    PhaseData = message['PhaseInput']
    model_input = np.zeros(shape=(1, N_t, n_features), dtype='float32')
    model_input[0] = pre_process_data(PhaseData, N_t, n_features)
    # If TagId == TOI, plot the data
    if TagId == TOI:
        # Update data in plot
        plot_data = pre_process_plot_data(PhaseData, N_t, n_features)
        x1 = list(plot_data[:, 0])
        x = [int(x1[i]) for i in range(N_t)]
        y1 = list(plot_data[:, 1])
        y = [int(y1[i]) for i in range(N_t)]

    # Make prediction with model
    with g.as_default():
        model_output = model.predict(model_input)
    # Take the second element (location at index (0, 1)) as the float prediction
    pred_val = model_output[0, 1]
    prediction = pred_val.tolist()
    # Construct response to send back to client app
    response = {
        'tag': TagId,
        'prediction': prediction
    }
    return jsonify(response)


@app.route('/data', methods=['GET', 'OPTIONS', 'POST'])
@crossdomain
def data():
    global x
    global y
    return jsonify(points=list(zip(x, y)))

# show and run
show(p)
# app.run(port=443)

0 个答案:

没有答案