Error predicting class with Keras and the Universal Sentence Encoder

Time: 2018-12-11 11:54:18

Tags: tensorflow keras

The inference code, using Flask and the Universal Sentence Encoder, is:

    # Assumed imports (not shown in the question). `model`, `clean_predict`
    # and `stopwords` are defined earlier in the script.
    import pickle

    import tensorflow as tf
    import tensorflow_hub as hub
    from flask import Flask, request
    from keras import backend as K

    model.load_weights('./universal_model2.h5')
    model._make_predict_function()
    with open('label_encoder1.pickle', 'rb') as handle:
        label_encoder = pickle.load(handle)
    K.clear_session()
    tf.reset_default_graph()

    g = tf.Graph()
    with g.as_default():
        text_input = tf.placeholder(dtype=tf.string, shape=[None])
        embed = hub.Module("/tmp/moduleA", trainable=True)
        encoding_tensor = embed(text_input)
        init_op = tf.group([tf.global_variables_initializer(),
                            tf.tables_initializer()])

    session = tf.Session(graph=g)
    K.set_session(session)
    session.run(init_op)

    app = Flask(__name__)

    @app.route('/predict', methods=['GET'])
    def my_form_post():
        text = request.args.get('text')
        text = clean_predict(text, stopwords)
        embe = session.run(encoding_tensor, feed_dict={text_input: [text]})
        pred = model.predict(embe)
        predict_logits = pred.argmax(axis=1)  # index of the top-scoring class
        pre = pred[0][predict_logits[0]]      # probability of that class
        classes = str(label_encoder.inverse_transform(predict_logits)[0])
        # (the rest of the view, including its return statement, is not shown)

I get this error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Tensor input_1:0, specified in either feed_devices or fetch_devices was not found in the Graph

0 answers