在运行示例 notebook new_toxic_multilabel.ipynb 中的以下代码时,出现了错误:
databunch = BertDataBunch(args['data_dir'], LABEL_PATH, args.model_name, train_file='train.csv', val_file='val.csv',
test_data='test.csv',
text_col="NOTES", label_col=label_cols,
batch_size_per_gpu=args['train_batch_size'], max_seq_length=args['max_seq_length'],
multi_gpu=args.multi_gpu, multi_label=True, model_type=args.model_type)
这是我传给 label_col 参数的 label_cols:
label_cols = ['CLASS_1','CLASS_2','CLASS_3','CLASS_4','CLASS_5','CLASS_6','CLASS_7','CLASS_8','CLASS_9','CLASS_10','CLASS_11','CLASS_12','CLASS_13','CLASS_14','CLASS_15','CLASS_16','CLASS_17','CLASS_E','CLASS_V']
我的 labels.csv 包含与上面相同的类别,每行一个。labels.csv 的内容如下:
'CLASS_1'
'CLASS_2'
'CLASS_3'
'CLASS_4'
'CLASS_5'
'CLASS_6'
'CLASS_7'
'CLASS_8'
'CLASS_9'
'CLASS_10'
'CLASS_11'
'CLASS_12'
'CLASS_13'
'CLASS_14'
'CLASS_15'
'CLASS_16'
'CLASS_17'
'CLASS_E'
'CLASS_V'
完整的回溯(traceback)如下:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-13-c5a2ac3a5e99> in <module>
3 text_col="NOTES", label_col=label_cols,
4 batch_size_per_gpu=args['train_batch_size'], max_seq_length=args['max_seq_length'],
----> 5 multi_gpu=args.multi_gpu, multi_label=True, model_type=args.model_type)
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/fast_bert/data_cls.py in __init__(self, data_dir, label_dir, tokenizer, train_file, val_file, test_data, label_file, text_col, label_col, batch_size_per_gpu, max_seq_length, multi_gpu, multi_label, backend, model_type, logger, clear_cache, no_cache)
352 if os.path.exists(cached_features_file) == False or self.no_cache == True:
353 train_examples = processor.get_train_examples(
--> 354 train_file, text_col=text_col, label_col=label_col)
355
356 train_dataset = self.get_dataset_from_examples(
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/fast_bert/data_cls.py in get_train_examples(self, filename, text_col, label_col, size)
230 data_df = pd.read_csv(os.path.join(self.data_dir, filename))
231
--> 232 return self._create_examples(data_df, "train", text_col=text_col, label_col=label_col)
233 else:
234 data_df = pd.read_csv(os.path.join(self.data_dir, filename))
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/fast_bert/data_cls.py in _create_examples(self, df, set_type, text_col, label_col)
286 else:
287 return list(df.apply(lambda row: InputExample(guid=row.index, text_a=row[text_col],
--> 288 label=_get_labels(row, label_col)), axis=1))
289
290
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/frame.py in apply(self, func, axis, broadcast, raw, reduce, result_type, args, **kwds)
6926 kwds=kwds,
6927 )
-> 6928 return op.get_result()
6929
6930 def applymap(self, func):
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/apply.py in get_result(self)
184 return self.apply_raw()
185
--> 186 return self.apply_standard()
187
188 def apply_empty_result(self):
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/apply.py in apply_standard(self)
290
291 # compute the result using the series generator
--> 292 self.apply_series_generator()
293
294 # wrap results
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/apply.py in apply_series_generator(self)
319 try:
320 for i, v in enumerate(series_gen):
--> 321 results[i] = self.f(v)
322 keys.append(v.name)
323 except Exception as e:
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/fast_bert/data_cls.py in <lambda>(row)
286 else:
287 return list(df.apply(lambda row: InputExample(guid=row.index, text_a=row[text_col],
--> 288 label=_get_labels(row, label_col)), axis=1))
289
290
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/fast_bert/data_cls.py in _get_labels(row, label_col)
273 def _get_labels(row, label_col):
274 if isinstance(label_col, list):
--> 275 return list(row[label_col])
276 else:
277 # create one hot vector of labels
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/series.py in __getitem__(self, key)
1111 key = check_bool_indexer(self.index, key)
1112
-> 1113 return self._get_with(key)
1114
1115 def _get_with(self, key):
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/series.py in _get_with(self, key)
1153 # handle the dup indexing case (GH 4246)
1154 if isinstance(key, (list, tuple)):
-> 1155 return self.loc[key]
1156
1157 return self.reindex(key)
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/indexing.py in __getitem__(self, key)
1422
1423 maybe_callable = com.apply_if_callable(key, self.obj)
-> 1424 return self._getitem_axis(maybe_callable, axis=axis)
1425
1426 def _is_scalar_access(self, key: Tuple):
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/indexing.py in _getitem_axis(self, key, axis)
1837 raise ValueError("Cannot index with multidimensional key")
1838
-> 1839 return self._getitem_iterable(key, axis=axis)
1840
1841 # nested tuple slicing
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/indexing.py in _getitem_iterable(self, key, axis)
1131 else:
1132 # A collection of keys
-> 1133 keyarr, indexer = self._get_listlike_indexer(key, axis, raise_missing=False)
1134 return self.obj._reindex_with_indexers(
1135 {axis: [keyarr, indexer]}, copy=True, allow_dups=True
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/indexing.py in _get_listlike_indexer(self, key, axis, raise_missing)
1090
1091 self._validate_read_indexer(
-> 1092 keyarr, indexer, o._get_axis_number(axis), raise_missing=raise_missing
1093 )
1094 return keyarr, indexer
~/virtualenvs/anaconda3/envs/pytorch/lib/python3.7/site-packages/pandas/core/indexing.py in _validate_read_indexer(self, key, indexer, axis, raise_missing)
1175 raise KeyError(
1176 "None of [{key}] are in the [{axis}]".format(
-> 1177 key=key, axis=self.obj._get_axis_name(axis)
1178 )
1179 )
KeyError: ("None of [Index(['CLASS_1', 'CLASS_2', 'CLASS_3', 'CLASS_4', 'CLASS_5', 'CLASS_6',\n 'CLASS_7', 'CLASS_8', 'CLASS_9', 'CLASS_10', 'CLASS_11', 'CLASS_12',\n 'CLASS_13', 'CLASS_14', 'CLASS_15', 'CLASS_16', 'CLASS_17', 'CLASS_E',\n 'CLASS_V'],\n dtype='object')] are in the [index]", 'occurred at index 0')
我已经在 Google 上搜索过,找到的最接近的结果是这个帖子(原文链接已丢失),但它仍然没有告诉我该如何解决这个问题。我认为错误与 label_col 有关,但找不出具体原因。如有任何帮助,我将不胜感激。谢谢!