我正在使用simpletransformers.classification训练Bert Moder对一些文本输入进行分类。这是我的代码。
from simpletransformers.classification import ClassificationModel
import torch
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from simpletransformers.classification import ClassificationModel
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification
import parallelTestModule
# Lets import the csv file in pandas dataframe first
train_df = pd.read_csv('D:\\7allV03Small.csv', encoding='utf-8', header=None, names=['cat', 'text'])
# Check the df
print(train_df.head())
# unique categories
print(train_df.cat.unique())
print("Total categories",len(train_df.cat.unique()))
# convert string labels to integers
train_df['labels'] = pd.factorize(train_df.cat)[0]
print(train_df.head())
# Let's create a train and test set
from sklearn.model_selection import train_test_split
train, test = train_test_split(train_df, test_size=0.2, random_state=42)
print('Eğitim veri seti boyutu : ' + str(train.shape), ' Test eğitim seti : ' + str(test.shape))
if __name__ == "__main__":
from multiprocessing import freeze_support
model = ClassificationModel('bert', 'bert-base-multilingual-uncased', use_cuda=False, num_labels=8, args={'reprocess_input_data': True, 'overwrite_output_dir': True, 'num_train_epochs': 1,'train_batch_size':1})
freeze_support()
# Now lets fine tune bert with the train set
model.train_model(train)
一切看起来都还不错,并且开始训练。但是在训练结束时,会出现如下错误。
Traceback (most recent call last):
File "c:/Users/arslanom/Desktop/text/try.py", line 45, in <module>
model.train_model(train)
File "C:\Users\arslanom\AppData\Roaming\Python\Python36\site-packages\simpletransformers\classification\classification_model.py", line 269, in train_model
**kwargs,
File "C:\Users\arslanom\AppData\Roaming\Python\Python36\site-packages\simpletransformers\classification\classification_model.py", line 544, in train
self._save_model(output_dir_current, optimizer, scheduler, model=model)
File "C:\Users\arslanom\AppData\Roaming\Python\Python36\site-packages\simpletransformers\classification\classification_model.py", line 1113, in _save_model
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
File "C:\Users\arslanom\AppData\Roaming\Python\Python36\site-packages\torch\serialization.py", line 209, in save
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "C:\Users\arslanom\AppData\Roaming\Python\Python36\site-packages\torch\serialization.py", line 134, in _with_file_like
return body(f)
File "C:\Users\arslanom\AppData\Roaming\Python\Python36\site-packages\torch\serialization.py", line 209, in <lambda>
return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
File "C:\Users\arslanom\AppData\Roaming\Python\Python36\site-packages\torch\serialization.py", line 282, in _save
pickler.dump(obj)
AttributeError: Can't pickle local object 'get_linear_schedule_with_warmup.<locals>.lr_lambda'
类似此问题的声音与worker_count有关,因为它使用多线程运行。但是我找不到任何解决方案。
操作系统:Windows 10
RAM:16 Gb