How to correctly save a spaCy textcat model so it can be loaded for later use

Asked: 2019-07-23 13:21:13

Tags: python-3.x machine-learning nlp spacy

I am trying to train a textcat model (shuffling the training data on each iteration) so that I can predict the text category of a given paragraph.

Training took some time, but it worked, and I can use the trained model to make predictions. However, after saving it with nlp.to_disk() I cannot load the model again.

What I have done so far:

import pickle
import random

import spacy

# nlp_ger was created earlier with a textcat pipe added to the pipeline
nlp_ger.begin_training()
TRAINING_DATA = pickle.load(open("../../Data/training_paragraphs.p", "rb"))

# Loop for x iterations
for itn in range(10):
    # Shuffle the training data
    random.shuffle(TRAINING_DATA)
    losses = {}

    # Batch the examples and iterate over them
    for batch in spacy.util.minibatch(TRAINING_DATA, size=1):
        # Pass the raw texts; nlp.update tokenizes them itself
        texts = [text for text, entities in batch]
        annotations = [{"cats": entities} for text, entities in batch]
        nlp_ger.update(texts, annotations, losses=losses)
    if itn % 20 == 0:
        print(losses)

# Saving the model to disk
nlp_ger.to_disk("basic_text_class_model2")

# Loading the new model
test = spacy.load("de_core_news_md")
test2 = test.from_disk("basic_text_class_model2/")
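
For reference, here is a minimal sketch of the round trip I expected to work: saving the whole pipeline with to_disk() and loading it back as a complete model via spacy.load(), rather than deserializing it into a freshly loaded de_core_news_md. I am using a blank German pipeline and a temporary directory here only to keep the sketch self-contained; the same round trip should apply to a pipeline with a trained textcat pipe:

```python
import tempfile

import spacy

# A blank German pipeline stands in for the trained model here
nlp = spacy.blank("de")

with tempfile.TemporaryDirectory() as model_dir:
    # Serialize the whole pipeline (meta.json, vocab, all pipes) ...
    nlp.to_disk(model_dir)
    # ... and load it back as a complete model, so every component's
    # weights keep the shapes they were saved with
    nlp_loaded = spacy.load(model_dir)
    print(nlp_loaded.lang)  # prints "de"
```
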

When I actually try to load the model, I get the following ValueError:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-100-d8792229fbea> in <module>
     1 #test=test.from_disk("/basic_text_class_model/")
----> 2 test2 = test.from_disk("basic_text_class_model2/")

~\Anaconda3\lib\site-packages\spacy\language.py in from_disk(self, path, exclude, disable)
   789             # Convert to list here in case exclude is (default) tuple
   790             exclude = list(exclude) + ["vocab"]
--> 791         util.from_disk(path, deserializers, exclude)
   792         self._path = path
   793         return self

~\Anaconda3\lib\site-packages\spacy\util.py in from_disk(path, readers, exclude)
   628         # Split to support file names like meta.json
   629         if key.split(".")[0] not in exclude:
--> 630             reader(path / key)
   631     return path
   632 

~\Anaconda3\lib\site-packages\spacy\language.py in <lambda>(p, proc)
   785             if not hasattr(proc, "from_disk"):
   786                 continue
--> 787             deserializers[name] = lambda p, proc=proc: proc.from_disk(p, exclude=["vocab"])
   788         if not (path / "vocab").exists() and "vocab" not in exclude:
   789             # Convert to list here in case exclude is (default) tuple

nn_parser.pyx in spacy.syntax.nn_parser.Parser.from_disk()

~\Anaconda3\lib\site-packages\thinc\neural\_classes\model.py in from_bytes(self, bytes_data)
   370                         name = name.decode("utf8")
   371                     dest = getattr(layer, name)
--> 372                     copy_array(dest, param[b"value"])
   373                 i += 1
   374             if hasattr(layer, "_layers"):

~\Anaconda3\lib\site-packages\thinc\neural\util.py in copy_array(dst, src, casting, where)
   122 def copy_array(dst, src, casting="same_kind", where=None):
   123     if isinstance(dst, numpy.ndarray) and isinstance(src, numpy.ndarray):
--> 124         dst[:] = src
   125     elif is_cupy_array(dst):
   126         src = cupy.array(src, copy=False)

ValueError: could not broadcast input array from shape (29,64) into shape (17,64)

Please let me know if I have made any mistakes in describing or formatting the question. I am still fairly new to Stack Overflow, so I am sure there is room for improvement.

0 Answers:

There are no answers yet.