带有Flask + Universal Sentence编码器的推断代码是:
model.load_weights('./universal_model2.h5')
model._make_predict_function()
with open('label_encoder1.pickle', 'rb') as handle:
label_encoder = pickle.load(handle)
K.clear_session()
tf.reset_default_graph()
g = tf.Graph()
with g.as_default():
text_input = tf.placeholder(dtype=tf.string, shape=[None])
embed = hub.Module("/tmp/moduleA",trainable=True)
encoding_tensor = embed(text_input)
init_op=tf.group([tf.global_variables_initializer(),
tf.tables_initializer()])
session = tf.Session(graph=g)
K.set_session(session)
session.run(init_op)
app = Flask(__name__)
@app.route('/predict', methods=['GET'])
def my_form_post():
text = request.args.get('text')
text = clean_predict(text,stopwords)
embe = session.run(encoding_tensor, feed_dict={text_input: [text]})
pred = model.predict(embe)
predict_logits = pred.argmax(axis=1)
pre=pred[0][predict_logits[0]]
classes = str(label_encoder.inverse_transform(predict_logits)[0])
我遇到此错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError:在图表中找不到在feed_devices或fetch_devices中指定的张量输入_1:0