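I want to use BERT to compute the semantic similarity between sentences. I found the following code on GitHub, which uses a model that has already been fine-tuned for semantic similarity: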
from semantic_text_similarity.models import WebBertSimilarity
from semantic_text_similarity.models import ClinicalBertSimilarity
web_model = WebBertSimilarity(device='cpu', batch_size=10)
The download reaches 100%, and then I get the following error (this is the last line):
TypeError: init_weights() takes 1 positional argument but 2 were given
I tried reading up on this error, but I don't understand where 2 positional arguments are being passed instead of 1.
I would appreciate any hints.
Thanks!
------------------------- Edit -------------------------
Here is the full error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-73-97be4030b59e> in <module>()
----> 1 web_model = WebBertSimilarity(device='cpu', batch_size=10) #defaults to GPU prediction
/anaconda3/lib/python3.6/site-packages/semantic_text_similarity/models/bert/web_similarity.py in __init__(self, device, batch_size, model_name)
6 def __init__(self, device='cuda', batch_size=10, model_name="web-bert-similarity"):
7 model_path = get_model_path(model_name)
----> 8 super().__init__(device=device, batch_size=batch_size, bert_model_path=model_path)
/anaconda3/lib/python3.6/site-packages/semantic_text_similarity/models/bert/similarity.py in __init__(self, args, device, bert_model_path, batch_size, learning_rate, weight_decay, additional_features)
80 config.pretrained_config_archive_map['additional_features'] = additional_features
81
---> 82 self.regressor_net = BertSimilarityRegressor.from_pretrained(self.args['bert_model_path'], config=config)
83 self.optimizer = torch.optim.Adam(
84 self.regressor_net.parameters(),
/anaconda3/lib/python3.6/site-packages/pytorch_transformers/modeling_utils.py in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
534 if hasattr(model, 'tie_weights'):
535 model.tie_weights() # make sure word embedding weights are still tied
--> 536
537 # Set model in evaluation mode to desactivate DropOut modules by default
538 model.eval()
/anaconda3/lib/python3.6/site-packages/semantic_text_similarity/models/bert/similarity.py in __init__(self, bert_model_config)
25 )
26
---> 27 self.apply(self.init_weights)
28
29
/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in apply(self, fn)
291 """
292 for module in self.children():
--> 293 module.apply(fn)
294 fn(self)
295 return self
/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in apply(self, fn)
291 """
292 for module in self.children():
--> 293 module.apply(fn)
294 fn(self)
295 return self
/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in apply(self, fn)
291 """
292 for module in self.children():
--> 293 module.apply(fn)
294 fn(self)
295 return self
/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in apply(self, fn)
292 for module in self.children():
293 module.apply(fn)
--> 294 fn(self)
295 return self
296
TypeError: init_weights() takes 1 positional argument but 2 were given
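To narrow down where the second argument comes from, here is a minimal sketch that reproduces the same message without torch. It assumes, based only on the traceback above, that `apply` calls `fn(self)` on every module and that `init_weights` is a bound method accepting nothing besides `self`. `TinyModule` and `Regressor` are hypothetical stand-ins, not classes from torch or semantic_text_similarity.

```python
class TinyModule:
    """Hypothetical stand-in for a module; it only mimics apply()."""

    def __init__(self, children=()):
        self._children = list(children)

    def apply(self, fn):
        # Same shape as the apply() frames in the traceback:
        # recurse into the children first, then call fn with one argument.
        for child in self._children:
            child.apply(fn)
        fn(self)
        return self


class Regressor(TinyModule):
    def init_weights(self):
        # Assumed signature: takes only `self`, no module argument.
        pass


def reproduce():
    model = Regressor(children=[TinyModule()])
    try:
        # model.init_weights is a bound method, so `self` is already filled in.
        # apply() then passes the visited module as a second positional argument.
        model.apply(model.init_weights)
    except TypeError as exc:
        return str(exc)
    return None


message = reproduce()
print(message)
```

If this matches what happens in the real code, the two arguments would be the implicit `self` of the bound method plus the module that `apply` passes in. One possible cause, which I have not confirmed, is that the installed pytorch_transformers defines `init_weights` with a different signature than the version this package was written against.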