I'm currently trying to fit a classification model for a multi-label text classification problem.
I have a training set X_train containing a list of cleaned texts, for example
["I am constructing Markov chains with to states and inferring transition probabilities empirically by simply counting how many times I saw each transition in my raw data",
 "I know the chips only of the players of my table and mine obviously I also know the total number of chips the max and min amount chips the players have and the average stackIs it possible to make an approximation of my probability of winningI have",
 ...]
and a training set of labels y, holding the multiple labels for each text in X_train, for example
[['hypothesis-testing', 'statistical-significance', 'markov-process'],
['probability', 'normal-distribution', 'games'],
...]
Now I want to fit a model that can predict the labels for a set of texts X_test, which has the same format as X_train.
I have already used MultiLabelBinarizer to transform the labels and TfidfVectorizer to transform the cleaned texts in the training set.
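from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import TfidfVectorizer

# Turn each list of tags into a 0/1 indicator row (one column per tag)
multilabel_binarizer = MultiLabelBinarizer()
multilabel_binarizer.fit(y)
Y = multilabel_binarizer.transform(y)
# stopWordList is my own list of stop words (not shown here)
vectorizer = TfidfVectorizer(stop_words=stopWordList)
vectorizer.fit(X_train)
x_train = vectorizer.transform(X_train)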
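To make the shapes involved concrete, here is a small self-contained sketch of the same preprocessing. The two short texts are hypothetical stand-ins for X_train, and the built-in "english" stop-word list replaces stopWordList, which is not shown in the question.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical toy data shaped like the question's X_train / y
X_train = ["markov chains transition probabilities counting",
           "chips players table probability winning"]
y = [["hypothesis-testing", "statistical-significance", "markov-process"],
     ["probability", "normal-distribution", "games"]]

multilabel_binarizer = MultiLabelBinarizer()
Y = multilabel_binarizer.fit_transform(y)  # dense 0/1 matrix, one column per tag

# Built-in English stop words stand in for the asker's stopWordList
vectorizer = TfidfVectorizer(stop_words="english")
x_train = vectorizer.fit_transform(X_train)  # sparse matrix, one row per text

print(multilabel_binarizer.classes_)  # tag names, sorted; these are Y's columns
print(Y)
print(x_train.shape)
```

Y is a 2-D matrix with one column per tag, which is why a plain single-output classifier cannot consume it directly.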
But whenever I try to fit a model, I get an error. I have tried OneVsRestClassifier and LogisticRegression.
When I use the OneVsRestClassifier model, I get
Traceback (most recent call last):
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/conda/envs/data3/lib/python3.6/socketserver.py", line 696, in __init__
    self.handle()
  File "/usr/local/spark/python/pyspark/accumulators.py", line 268, in handle
    poll(accum_updates)
  File "/usr/local/spark/python/pyspark/accumulators.py", line 241, in poll
    if func():
  File "/usr/local/spark/python/pyspark/accumulators.py", line 245, in accum_updates
    num_updates = read_int(self.rfile)
  File "/usr/local/spark/python/pyspark/serializers.py", line 714, in read_int
    raise EOFError
EOFError
When I use the LogisticRegression model, I get
/opt/conda/envs/data3/lib/python3.6/site-packages/sklearn/linear_model/sag.py:326: ConvergenceWarning: The max_iter was reached which means the coef_ did not converge
  "the coef_ did not converge", ConvergenceWarning)
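The warning means the solver stopped at max_iter before converging; the path (sag.py) suggests the "sag" solver was in use. One common adjustment, which is an editorial suggestion and not part of the original thread, is to raise max_iter (its default is 100). Note also that LogisticRegression on its own expects a one-dimensional target, so the binarized Y still needs to go through OneVsRestClassifier. The toy data and the value 1000 below are illustrative assumptions.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical toy data shaped like the question's X_train / y
X_train = ["markov chains transition probabilities counting",
           "chips players table probability winning"]
y = [["hypothesis-testing", "statistical-significance", "markov-process"],
     ["probability", "normal-distribution", "games"]]
Y = MultiLabelBinarizer().fit_transform(y)
x_train = TfidfVectorizer(stop_words="english").fit_transform(X_train)

# solver="sag" is assumed from the warning's file path; max_iter=1000 is an
# arbitrary example value giving the solver more iterations than the default 100
classifier = OneVsRestClassifier(LogisticRegression(solver="sag", max_iter=1000))
classifier.fit(x_train, Y)
print(classifier.predict(x_train).shape)  # one 0/1 row per text, one column per tag
```

On real data, scaling or a different solver may also be worth trying if raising max_iter alone does not remove the warning.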
Does anyone know what the problem is and how to fix it? Thanks a lot.
Answer 0 (score: 0)
OneVsRestClassifier fits one classifier per class. You need to tell it which type of classifier to use (for example, logistic regression).
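To illustrate what "one classifier per class" means, the sketch below fits one binary LogisticRegression per column of the binarized label matrix by hand and compares the result with OneVsRestClassifier. The toy data is hypothetical; the manual loop is only for explanation, not a replacement for the wrapper.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical toy data shaped like the question's X_train / y
X_train = ["markov chains transition probabilities counting",
           "chips players table probability winning"]
y = [["hypothesis-testing", "statistical-significance", "markov-process"],
     ["probability", "normal-distribution", "games"]]
Y = MultiLabelBinarizer().fit_transform(y)
x_train = TfidfVectorizer(stop_words="english").fit_transform(X_train)

# By hand: one independent binary classifier per tag column of Y
per_tag = [LogisticRegression().fit(x_train, Y[:, j]) for j in range(Y.shape[1])]
manual_pred = np.column_stack([clf.predict(x_train) for clf in per_tag])

# OneVsRestClassifier does the same bookkeeping internally
ovr = OneVsRestClassifier(LogisticRegression()).fit(x_train, Y)
print(manual_pred)
print(ovr.predict(x_train))
```

Each column of the prediction comes from its own binary model, so tags are predicted independently of one another.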
The following code works for me:
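from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

# Wrap a binary classifier so that one copy is fitted per column of Y
classifier = OneVsRestClassifier(LogisticRegression())
classifier.fit(x_train, Y)

# New texts go through the already-fitted vectorizer (transform, not fit)
X_test = ["I play with Markov chains"]
x_test = vectorizer.transform(X_test)
classifier.predict(x_test)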
Output: array([[0,1,1,0,0,1]])
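The prediction is a 0/1 row whose columns follow multilabel_binarizer.classes_. A minimal end-to-end sketch, again on hypothetical toy data, that maps such a row back to tag names with MultiLabelBinarizer.inverse_transform:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Hypothetical toy data shaped like the question's X_train / y
X_train = ["markov chains transition probabilities counting",
           "chips players table probability winning"]
y = [["hypothesis-testing", "statistical-significance", "markov-process"],
     ["probability", "normal-distribution", "games"]]

multilabel_binarizer = MultiLabelBinarizer()
Y = multilabel_binarizer.fit_transform(y)
vectorizer = TfidfVectorizer(stop_words="english")
x_train = vectorizer.fit_transform(X_train)

classifier = OneVsRestClassifier(LogisticRegression()).fit(x_train, Y)

# Predict one 0/1 row per test text, then translate columns back into tags
X_test = ["I play with Markov chains"]
y_pred = classifier.predict(vectorizer.transform(X_test))
predicted_tags = multilabel_binarizer.inverse_transform(y_pred)
print(predicted_tags)  # list with one tuple of tag names per test text
```

Using the same fitted binarizer for encoding and decoding keeps the column order consistent between training and prediction.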