如何将两个要素/分类器组合成一个统一且更好的分类器?

时间:2015-07-01 15:56:46

标签: python matrix scikit-learn classification feature-selection

我正在创建一个情感识别程序,并设法生成两种不同的算法/功能,以提供给sklearn的SVM。我获得了密集光流数据,将其压缩成矩阵并将其输入SVM功能,同时我也使用面部地标跟踪数据。

现在,我有两个不同的程序,每个程序产生不同的准确性,但做同样的事情:根据面部运动识别情绪。

我现在的目标是将密集光流和面部地标特征/分类器结合在一起并将它们联合起来,以获得一个更好的分类器,使用这两个分类器来实现更高的分类准确度。

基本上,我试图从这里重新创建一个分类器: http://cgit.nutn.edu.tw:8080/cgit/PaperDL/LZJ_120826151743.PDF

Confusion Matrix for Dense Optical Flow:   
[[27 22  0  0]  
 [ 0 57  1  0]  
 [ 0 12 60  0]  
 [ 0  9  3 68]]  
Accuracy: 80-90% range

Confusion Matrix for Facial Landmarks:  
[[27 10  5  2]  
 [ 7 44  5  3]  
 [ 6 14 33  1]  
 [ 1 13  1 60]]  
Accuracy: 60-72% range

包含密集光流数据的矩阵的矩阵结构:

>>> main.shape
(646, 403680)
>>> main
array([[ -1.18353125e-03,  -2.41295085e-04,  -1.88367767e-03, ...,
         -5.19892928e-05,   8.53588153e-06,  -3.90818786e-05],
       [  6.32877424e-02,  -7.24349543e-02,   8.19472596e-02, ...,
         -4.71765925e-05,   5.41217596e-05,  -3.12083102e-05],
       [ -1.66368652e-02,   2.50510368e-02,  -6.03965335e-02, ...,
         -9.85100851e-05,  -7.69595645e-05,  -7.09727174e-05],
       ..., 
       [ -3.44874617e-03,   5.31123485e-03,  -8.47499538e-03, ...,
         -2.77953018e-06,  -2.96417579e-06,  -1.51305017e-06],
       [  3.24894954e-03,   5.05338283e-03,   3.91049543e-03, ...,
         -3.23493354e-04,   1.30995919e-04,  -3.06804082e-04],
       [  7.82454386e-03,   1.69946514e-02,   8.11014231e-03, ...,
         -1.02751539e-03,   7.68289610e-05,  -7.82517891e-04]], dtype=float32)  

包含人脸地标跟踪信息的矩阵的矩阵结构:

>>> main.shape
(646, 17, 68, 2)
>>> main
array([[[[  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]],

        [[ -2.23606798,  -1.10714872],
         [ -2.23606798,  -1.10714872],
         [  3.        ,   1.        ],
         ..., 
         [  1.41421356,   0.78539816],
         [  1.41421356,   0.78539816],
         [  1.        ,   0.        ]],

        [[  2.82842712,  -0.78539816],
         [  2.23606798,  -1.10714872],
         [  2.23606798,  -1.10714872],
         ..., 
         [ -1.        ,  -0.        ],
         [ -1.        ,  -0.        ],
         [ -1.        ,  -0.        ]],

        ..., 
        [[  2.        ,   1.        ],
         [ -2.23606798,   1.10714872],
         [ -3.16227766,   1.24904577],
         ..., 
         [ -1.        ,  -0.        ],
         [ -1.41421356,   0.78539816],
         [ -1.        ,  -0.        ]],

        [[ -1.41421356,  -0.78539816],
         [  1.        ,   1.        ],
         [ -1.41421356,  -0.78539816],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]],

        [[  3.        ,   1.        ],
         [  4.        ,   1.        ],
         [  4.        ,   1.        ],
         ..., 
         [  1.41421356,  -0.78539816],
         [  1.        ,   0.        ],
         [  1.        ,   0.        ]]],


       [[[  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]],

        [[  1.        ,   1.        ],
         [ -1.41421356,  -0.78539816],
         [ -1.        ,  -0.        ],
         ..., 
         [  2.        ,   0.        ],
         [  1.        ,   0.        ],
         [ -1.        ,  -0.        ]],

        [[  0.        ,   0.        ],
         [  1.        ,   1.        ],
         [  0.        ,   0.        ],
         ..., 
         [ -4.        ,  -0.        ],
         [ -3.        ,  -0.        ],
         [ -2.        ,  -0.        ]],

        ..., 
        [[ -2.23606798,  -1.10714872],
         [ -2.23606798,  -1.10714872],
         [  2.        ,   1.        ],
         ..., 
         [  0.        ,   0.        ],
         [  1.41421356,   0.78539816],
         [  1.41421356,   0.78539816]],

        [[  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [ -1.        ,  -0.        ],
         ..., 
         [  1.        ,   1.        ],
         [  0.        ,   0.        ],
         [ -1.41421356,   0.78539816]],

        [[  1.        ,   1.        ],
         [  1.        ,   1.        ],
         [  1.        ,   1.        ],
         ..., 
         [  1.        ,   1.        ],
         [  0.        ,   0.        ],
         [  1.        ,   0.        ]]],


       [[[  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]],

        [[  3.16227766,   1.24904577],
         [  2.23606798,   1.10714872],
         [  2.23606798,   1.10714872],
         ..., 
         [ -1.41421356,  -0.78539816],
         [ -1.        ,  -0.        ],
         [ -1.41421356,   0.78539816]],

        [[ -1.41421356,   0.78539816],
         [  0.        ,   0.        ],
         [  1.41421356,   0.78539816],
         ..., 
         [ -1.41421356,   0.78539816],
         [ -1.        ,  -0.        ],
         [  0.        ,   0.        ]],

        ..., 
        [[  1.        ,   1.        ],
         [  1.        ,   1.        ],
         [  0.        ,   0.        ],
         ..., 
         [  1.        ,   1.        ],
         [  1.        ,   1.        ],
         [ -1.41421356,   0.78539816]],

        [[  1.        ,   1.        ],
         [  2.        ,   1.        ],
         [  2.23606798,   1.10714872],
         ..., 
         [  1.        ,   1.        ],
         [  1.        ,   1.        ],
         [ -1.41421356,  -0.78539816]],

        [[  1.        ,   1.        ],
         [  1.        ,   1.        ],
         [  1.        ,   1.        ],
         ..., 
         [ -2.        ,  -0.        ],
         [ -2.        ,  -0.        ],
         [ -1.        ,  -0.        ]]],


       ..., 
       [[[  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]],

        [[  1.41421356,   0.78539816],
         [  1.41421356,   0.78539816],
         [  1.41421356,   0.78539816],
         ..., 
         [  1.        ,   1.        ],
         [  0.        ,   0.        ],
         [  1.        ,   1.        ]],

        [[  5.        ,   1.        ],
         [ -4.12310563,   1.32581766],
         [ -4.12310563,   1.32581766],
         ..., 
         [  1.        ,   1.        ],
         [  0.        ,   0.        ],
         [  1.        ,   1.        ]],

        ..., 
        [[  3.16227766,   1.24904577],
         [  2.        ,   1.        ],
         [  2.        ,   1.        ],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]],

        [[ -3.16227766,   1.24904577],
         [  2.        ,   1.        ],
         [ -2.23606798,   1.10714872],
         ..., 
         [  0.        ,   0.        ],
         [  1.        ,   0.        ],
         [  1.        ,   0.        ]],

        [[  1.        ,   1.        ],
         [  1.        ,   1.        ],
         [  1.41421356,   0.78539816],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]]],


       [[[  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]],

        [[ -2.23606798,   0.46364761],
         [ -1.41421356,   0.78539816],
         [ -2.23606798,   0.46364761],
         ..., 
         [  1.        ,   0.        ],
         [  1.        ,   0.        ],
         [  1.        ,   1.        ]],

        [[ -2.23606798,  -0.46364761],
         [ -1.41421356,  -0.78539816],
         [  2.        ,   1.        ],
         ..., 
         [  0.        ,   0.        ],
         [  1.        ,   0.        ],
         [  1.        ,   0.        ]],

        ..., 
        [[  1.        ,   0.        ],
         [  1.        ,   1.        ],
         [ -2.23606798,  -1.10714872],
         ..., 
         [ 19.02629759,   1.51821327],
         [ 19.        ,   1.        ],
         [-19.10497317,  -1.46591939]],

        [[  3.60555128,   0.98279372],
         [  3.60555128,   0.5880026 ],
         [  5.        ,   0.64350111],
         ..., 
         [  7.28010989,  -1.29249667],
         [  7.61577311,  -1.16590454],
         [  8.06225775,  -1.05165021]],

        [[ -7.28010989,   1.29249667],
         [ -5.        ,   0.92729522],
         [ -5.83095189,   0.5404195 ],
         ..., 
         [ 20.09975124,   1.47112767],
         [ 21.02379604,   1.52321322],
         [-20.22374842,  -1.42190638]]],


       [[[  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         ..., 
         [  0.        ,   0.        ],
         [  0.        ,   0.        ],
         [  0.        ,   0.        ]],

        [[ -1.41421356,   0.78539816],
         [ -2.23606798,   1.10714872],
         [  2.        ,   1.        ],
         ..., 
         [  1.        ,   1.        ],
         [  1.        ,   0.        ],
         [  2.23606798,  -0.46364761]],

        [[  1.        ,   0.        ],
         [  1.41421356,   0.78539816],
         [  1.        ,   1.        ],
         ..., 
         [  0.        ,   0.        ],
         [  1.        ,   1.        ],
         [  0.        ,   0.        ]],

        ..., 
        [[ -1.41421356,  -0.78539816],
         [  0.        ,   0.        ],
         [  1.        ,   1.        ],
         ..., 
         [  1.        ,   0.        ],
         [  1.        ,   0.        ],
         [  1.        ,   0.        ]],

        [[  1.        ,   1.        ],
         [ -1.        ,  -0.        ],
         [  1.        ,   1.        ],
         ..., 
         [ -1.        ,  -0.        ],
         [  0.        ,   0.        ],
         [ -1.        ,  -0.        ]],

        [[  0.        ,   0.        ],
         [  1.41421356,  -0.78539816],
         [ -1.        ,  -0.        ],
         ..., 
         [  1.        ,   0.        ],
         [  0.        ,   0.        ],
         [  1.        ,   0.        ]]]])

我的密集光流分类器代码:

features_train, features_test, labels_train, labels_test = cross_validation.train_test_split(main, target, test_size = 0.4)

# Determine amount of time to train
t0 = time()
model = SVC(probability=True)
#model = SVC(kernel='poly')
#model = GaussianNB()

model.fit(features_train, labels_train)

print 'training time: ', round(time()-t0, 3), 's'

# Determine amount of time to predict
t1 = time()
pred = model.predict(features_test)

我的脸部地标跟踪分类器代码:

features_train, features_test, labels_train, labels_test = cross_validation.train_test_split(main.reshape(len(main), -1), target, test_size = 0.4)

# Determine amount of time to train
t0 = time()
#model = SVC()
model = SVC(kernel='linear')

#model = GaussianNB()

model.fit(features_train, labels_train)


# Determine amount of time to predict
t1 = time()
pred = model.predict(features_test)

在sklearn(或一般的机器学习)中,如何将这两个功能组合在一起,以创建一个统一且更好的分类器,在训练和预测时将这两个信息考虑在内?

2 个答案:

答案 0 :(得分:3)

看看VotingClassifier。它允许您组合多个分类器,并根据每个分类器的各个预测选择最终预测。

以上假设你可以使用sklearn的dev(0.17)版本。如果没有,您可以将VotingClassifier源复制到您的代码中:https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/ensemble/voting_classifier.py。代码非常简单。

答案 1 :(得分:0)

您可以根据Daniel的建议构建单独的分类器。但是,您可以考虑连接两个数据集:

main_dense_optical
main_face_landmark = main_face_landmark.reshape( len(main_face_landmark), -1 )
main = np.concatenate( [main_dense_optical, main_face_landmark], axis=1 )
# Code for train/test, training, evaluating here