通用句子编码器中的Tensorflow会话错误

时间:2020-05-14 20:30:40

标签: python tensorflow tensorflow2.0 tensorflow-serving sentence-similarity

我为通用句子编码器提供了以下代码,一旦将模型加载到烧瓶api 中并尝试点击它,它就会产生以下错误(检查如下):

'''

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
model_2 = hub.load(module_url)
print ("module %s loaded" % module_url)

def embed(input):
    return model_2(input)


def universalModel(messages):
    accuracy = []
    similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
    similarity_message_encodings = embed(similarity_input_placeholder)
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        message_embeddings_ = session.run(similarity_message_encodings, feed_dict={similarity_input_placeholder: messages})

        corr = np.inner(message_embeddings_, message_embeddings_)
        accuracy.append(corr[0,1])
    # print(corr[0,1])
    return "%.2f" % accuracy[0]

'''

在烧瓶API中使用模型时出现的以下错误: tensorflow.python.framework.errors_impl.InvalidArgumentError:图形无效,包含一个带有1个节点的循环,包括:StatefulPartitionedCall 尽管此代码可以正常运行,但在colab笔记本中。

我正在使用Tensorflow版本2.2.0。

1 个答案:

答案 0 :(得分:1)

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

这两行旨在使tensorflow 2.x变为tensorflow1.x。

对于Tensorflow 1.x,这是在与flask,django等一起使用时的常见问题。 您必须定义一个图和会话进行推断,

将tensorflow导入为tf 导入tensorflow_hub作为中心

# Create graph and finalize (finalizing optional but recommended).
g = tf.Graph()
with g.as_default():
  # We will be feeding 1D tensors of text into the graph.
  text_input = tf.placeholder(dtype=tf.string, shape=[None])
  embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
  embedded_text = embed(text_input)
  init_op = tf.group([tf.global_variables_initializer(), tf.tables_initializer()])
g.finalize()

# Create session and initialize.
session = tf.Session(graph=g)
session.run(init_op)

输入请求可以通过

处理
result = session.run(embedded_text, feed_dict={text_input: ["Hello world"]})

有关详细信息 https://www.tensorflow.org/hub/common_issues

对于tensorflow 2.x会话和图形不是必需的。

import tensorflow as tf
module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
model_2 = hub.load(module_url)
print ("module %s loaded" % module_url)

def embed(input):
    return model_2(input)
#pass messages as list
def universalModel(messages):
    accuracy = []
    message_embeddings_= embed(messages)
    corr = np.inner(message_embeddings_, message_embeddings_)
    accuracy.append(corr[0,1])
    # print(corr[0,1])
    return "%.2f" % accuracy[0]