How do I use a list of strings as training data for an SVM with scikit-learn?

Asked: 2013-12-02 18:17:43

Tags: python-2.7 nlp svm scikit-learn

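I am using scikit-learn to train an SVM on data where each observation (X) is a list of words and the label (Y) for each observation is a float. I tried to follow the multi-class classification example in the scikit-learn documentation (http://scikit-learn.org/stable/modules/svm.html). Here is my code: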

from __future__ import division
from sklearn import svm
import os.path
import numpy

import re

'''
The stanford-postagger was included to see how it tags the words and to see if it would help in getting just the names
of the ingredients. Turns out it's pointless.
'''
#from nltk.tag.stanford import POSTagger
mainDirectory = './nyu/PROJECTS/Epicurious/DATA/ingredients'
#st = POSTagger('/usr/share/stanford-postagger/models/english-bidirectional-distsim.tagger','/usr/share/stanford-postagger/stanford-postagger.jar')

'''
This is where we would read each line of the file and then run a regex match on it to get all the words before
the first tab. (these are the names of the ingredients. Some of them may have adjectives like fresh, peeled,cut etc.
    Not sure what to do about them yet.)


'''
def getFileDetails(_filename,_fileDescriptor):
    rankingRegexMatch = re.match('([0-9](?:\_)[0-9]?)', _filename)

    if len(rankingRegexMatch.group(0)) == 2:
        ranking = float(rankingRegexMatch.group(0)[0])
    else:
        ranking = float(rankingRegexMatch.group(0)[0]+'.'+rankingRegexMatch.group(0)[2])

    _keywords = []
    for line in _fileDescriptor:
        m = re.match('(\w+\s*\w*)(?=\t[0-9])', line)
        if m:
            _keywords.append(m.group(0))

    return [_keywords,ranking]

'''
Open each file in the directory and pass the name and file descriptor to getFileDetails
'''
def this_is_it(files):
    _allKeywords = []
    _allRankings = []
    for eachFile in files:
        fullFilePath = mainDirectory + '/' + eachFile
        f = open(fullFilePath)
        XandYForThisFile = getFileDetails(eachFile,f)
        _allKeywords.append(XandYForThisFile[0])
        _allRankings.append(XandYForThisFile[1])
    #_allKeywords = numpy.array(_allKeywords,dtype=object)
    svm_learning(_allKeywords,_allRankings)



def svm_learning(x,y):
    clf = svm.SVC()
    clf.fit(x,y)
'''
This just prints the directory path and then calls the callback x on files
'''
def print_files( x, dir_path , files ):
    print dir_path
    x(files)
'''
code starts here
'''
os.path.walk(mainDirectory, print_files, this_is_it)

When svm_learning(x, y) is called, it raises an error:

Traceback (most recent call last):
  File "scan for files.py", line 72, in <module>
    os.path.walk(mainDirectory, print_files, this_is_it)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/posixpath.py", line 238, in walk
    func(arg, top, names)
  File "scan for files.py", line 68, in print_files
    x(files)
  File "scan for files.py", line 56, in this_is_it
    svm_learning(_allKeywords,_allRankings)
  File "scan for files.py", line 62, in svm_learning
    clf.fit(x,y)
  File "/Library/Python/2.7/site-packages/scikit_learn-0.14_git-py2.7-macosx-10.8-intel.egg/sklearn/svm/base.py", line 135, in fit
    X = atleast2d_or_csr(X, dtype=np.float64, order='C')
  File "/Library/Python/2.7/site-packages/scikit_learn-0.14_git-py2.7-macosx-10.8-intel.egg/sklearn/utils/validation.py", line 116, in atleast2d_or_csr
    "tocsr")
  File "/Library/Python/2.7/site-packages/scikit_learn-0.14_git-py2.7-macosx-10.8-intel.egg/sklearn/utils/validation.py", line 96, in _atleast2d_or_sparse
    X = array2d(X, dtype=dtype, order=order, copy=copy)
  File "/Library/Python/2.7/site-packages/scikit_learn-0.14_git-py2.7-macosx-10.8-intel.egg/sklearn/utils/validation.py", line 80, in array2d
    X_2d = np.asarray(np.atleast_2d(X), dtype=dtype, order=order)
  File "/Library/Python/2.7/site-packages/numpy-1.8.0.dev_bbcfcf6_20130307-py2.7-macosx-10.8-intel.egg/numpy/core/numeric.py", line 331, in asarray
    return array(a, dtype, copy=False, order=order)
ValueError: setting an array element with a sequence.

Can anyone help? I am new to scikit-learn and could not find anything helpful in the documentation.

1 answer:

Answer 0 (score: 0)

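You should take a look at Text feature extraction in the scikit-learn documentation. You will want to use TfidfVectorizer, CountVectorizer, or HashingVectorizer (if your data is very large). These components take your text as input and output a feature matrix that a classifier can accept. Note that they work on a list of strings, one string per example, so if you have a list of lists of strings (that is, you have already tokenized), you may need to join() the tokens to get a list of strings, or skip the tokenization step.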
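As a rough illustration of both options, here is a minimal sketch. The traceback in the question shows fit() trying to convert X to a float64 array, which cannot work for ragged lists of words, so the words have to be turned into numeric features first. The toy ingredient lists and ratings below are made up, and the identity helper is a hypothetical name. SVR is used because the targets are floats; that is an assumption about the intent, and SVC would be the choice if the rankings are really discrete class labels.

```python
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.svm import SVR

# Hypothetical toy data: one list of ingredient tokens per observation,
# and one float rating per observation (values invented for illustration).
token_lists = [
    ["olive oil", "garlic", "basil"],
    ["butter", "sugar", "flour"],
    ["garlic", "butter", "parsley"],
    ["sugar", "flour", "eggs"],
]
ratings = [3.5, 4.0, 3.0, 4.5]

# Option 1: join each token list into a single string, then let the
# vectorizer tokenize again. Multi-word tokens such as "olive oil"
# are split into separate words by the default tokenizer.
documents = [" ".join(tokens) for tokens in token_lists]
vectorizer = TfidfVectorizer()
X = vectorizer.fit_transform(documents)  # sparse matrix, one row per document


# Option 2: skip tokenization by passing a callable analyzer that
# returns the tokens unchanged, so "olive oil" stays one feature.
def identity(tokens):
    return tokens


pretokenized = CountVectorizer(analyzer=identity)
X_tokens = pretokenized.fit_transform(token_lists)

# Fit a support vector regressor on the numeric feature matrix.
model = SVR(kernel="linear")
model.fit(X, ratings)

# New observations must go through the same fitted vectorizer.
new_documents = [" ".join(["garlic", "basil"])]
predictions = model.predict(vectorizer.transform(new_documents))
```

In the question's code, this would mean building the feature matrix from _allKeywords with one of the two options before calling fit, instead of passing the raw lists of words.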