加载预训练的通用句子编码器的问题

时间:2020-04-29 00:26:51

标签: python tensorflow nlp tensorflow2.0 pre-trained-model

我在为NLP任务加载经过预训练的模块时遇到问题,并且该问题是由于我想的tf迁移引起的。 Tensorflow网站说,如果正确给出签名变量,则可能会解决该问题。您能帮我纠正此代码吗?

TypeError:“ AutoTrackable”对象不可调用

[代码]

import tensorflow_hub as hub
# enabling the pretrained model for trainig our custom model using tensorflow hub
module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
embed = hub.load(module_url)

# creating a method for embedding and will using method for every input layer 
def UniversalEmbedding(x):
    return embed(tf.squeeze(tf.cast(x, tf.string)), signature='default', as_dict=True)["default"]

1 个答案:

答案 0 :(得分:0)

我是该领域的新手,但我可以与您分享解决问题的方法:

  • 使用以下命令检查您的Tensorflow版本:

    print(tensorflow。版本

  • 您可以在运行导入tensorflow之前运行以下代码来更改Tensorflow版本

%tensorflow_version 1.x

  • 我的TF版本是1.x,我正在使用hub.load(url)并收到“自动跟踪”错误消息,当我将其替换为hub.Module(url)时,它解决了我的问题

  • 这些链接可能对您有所帮助!如果您不了解它们

https://tfhub.dev/google/elmo/3 https://www.tensorflow.org/hub/common_issues

祝你好运