我正在使用LIME来解释多类分类Xgboost模型的结果。目标包含6个标签,这些标签已使用LabelEncoder进行了编码。
from lime import lime_text
from lime.lime_text import LimeTextExplainer
print(le_label, type(le_label[0]))
explainer = LimeTextExplainer(class_names=le_label)
exp = explainer.explain_instance(X_test.iloc[0][0],
tc.predict_proba,num_features=10,labels=le_label)
# print(exp.available_labels())
# print(exp.as_list())
print(exp)
exp.show_in_notebook(text=False)
错误
[0 1 2 3 4 5] <class 'numpy.int64'>
/home/joe/anaconda3/envs/spotlight/lib/python3.6/re.py:212: FutureWarning:
split() requires a non-empty pattern match.
[0, 1, 2, 3, 4, 5]
<lime.explanation.Explanation object at 0x7f921cff10b8>
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-79-0416337174b0> in <module>
17 # print(exp.as_list())
18 print(exp)
---> 19 exp.show_in_notebook(text=False)
20
21 # exp_js = '''var exp_div;
~/.local/lib/python3.6/site-packages/lime/explanation.py in show_in_notebook(self, labels, predict_proba, show_predicted_value, **kwargs)
198 predict_proba=predict_proba,
199 show_predicted_value=show_predicted_value,
--> 200 **kwargs)))
201
202 def save_to_file(self,
~/.local/lib/python3.6/site-packages/lime/explanation.py in as_html(self, labels, predict_proba, show_predicted_value, **kwargs)
286 exp_js = '''var exp_div;
287 var exp = new lime.Explanation(%s);
--> 288 ''' % (jsonize(self.class_names))
289
290 if self.mode == "classification":
~/.local/lib/python3.6/site-packages/lime/explanation.py in jsonize(x)
244
245 def jsonize(x):
--> 246 return json.dumps(x, ensure_ascii=False)
247
248 if labels is None and self.mode == "classification":
~/anaconda3/envs/spotlight/lib/python3.6/json/__init__.py in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
236 check_circular=check_circular, allow_nan=allow_nan, indent=indent,
237 separators=separators, default=default, sort_keys=sort_keys,
--> 238 **kw).encode(obj)
239
240
~/anaconda3/envs/spotlight/lib/python3.6/json/encoder.py in encode(self, o)
197 # exceptions aren't as detailed. The list call should be roughly
198 # equivalent to the PySequence_Fast that ''.join() would do.
--> 199 chunks = self.iterencode(o, _one_shot=True)
200 if not isinstance(chunks, (list, tuple)):
201 chunks = list(chunks)
~/anaconda3/envs/spotlight/lib/python3.6/json/encoder.py in iterencode(self, o, _one_shot)
255 self.key_separator, self.item_separator, self.sort_keys,
256 self.skipkeys, _one_shot)
--> 257 return _iterencode(o, 0)
258
259 def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,
~/anaconda3/envs/spotlight/lib/python3.6/json/encoder.py in default(self, o)
178 """
179 raise TypeError("Object of type '%s' is not JSON serializable" %
--> 180 o.__class__.__name__)
181
182 def encode(self, o):
TypeError: Object of type 'ndarray' is not JSON serializable
我已经提到了这个github问题(https://github.com/marcotcr/lime/issues/272),并更改了标签数据类型tp np.float64,但是随后出现以下错误。
新错误(我知道这表示将标签转换为整数dtype):
[0. 1. 2. 3. 4. 5.] <class 'numpy.float64'>
/home/joe/anaconda3/envs/spotlight/lib/python3.6/re.py:212: FutureWarning:
split() requires a non-empty pattern match.
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-83-ab75566a3ac2> in <module>
11
12 exp = explainer.explain_instance(X_test.iloc[0][0],
---> 13 tc.predict_proba,num_features=10,labels=le_label)
14
15 # print(exp.available_labels())
~/.local/lib/python3.6/site-packages/lime/lime_text.py in explain_instance(self, text_instance, classifier_fn, labels, top_labels, num_features, num_samples, distance_metric, model_regressor)
415 data, yss, distances, label, num_features,
416 model_regressor=model_regressor,
--> 417 feature_selection=self.feature_selection)
418 return ret_exp
419
~/.local/lib/python3.6/site-packages/lime/lime_base.py in explain_instance_with_data(self, neighborhood_data, neighborhood_labels, distances, label, num_features, feature_selection, model_regressor)
151
152 weights = self.kernel_fn(distances)
--> 153 labels_column = neighborhood_labels[:, label]
154 used_features = self.feature_selection(neighborhood_data,
155 labels_column,
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices