如何使用来自其Github和CheckPoint文件的预训练过的Tensorflow模型运行推理

时间:2019-04-13 18:47:24

标签: python tensorflow

我想从此github中的模型中收集文本嵌入

https://github.com/dmis-lab/biobert

在安装过程中,仅显示

  

要使用BioBERT,我们需要预先训练的BioBERT重量,您可以   从Naver GitHub存储库下载BioBERT的预训练权重。   确保指定用于您的预训练砝码的版本   作品。另外,请注意,此存储库基于BERT存储库   由Google提供。

     

所有微调实验均在单个TITAN Xp上进行   具有12GB RAM的GPU机器。该代码已经过Python2测试   和Python3(我们将Python2用于实验)。你可能想要   安装Java以使用BioASQ的官方评估脚本。看到   其他细节请参见requirements.txt。

我可以下载他们的检查点文件并使用

加载它
with tf.Session(graph=graph) as session:

   saver.restore(session, 'BioBert.ckpt' )

并使用类似的方法安装他们的github

!test -d bioBert_repo|| git clone https://github.com/dmis-lab/biobert bioBert_repo

但是如何从文本输入中获取嵌入。说明说它是基于BERT的,但是对于BERT,我们要做的就是导入tf.hub模型

bert_module = hub.Module(
 "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1",
 trainable=False)

然后 将标记化的句子放入其中

bert_embedding= bert_module(inputs=tokenized_sentence, signature="tokens", as_dict=True)[
       "pooled_output"
   ]

我猜想有一种类似的方法,我可以安装github并加载权重,但似乎找不到它。

1 个答案:

答案 0 :(得分:0)

您应该从extract_features.py中看到示例。 我想BIOBert不使用tf.hub。