使用OpenCV训练线性SVM模型以识别一张脸

时间:2019-11-05 23:52:16

标签: opencv machine-learning scikit-learn opencv4

我正在尝试遵循this tutorial来创建python程序,以识别视频流中的人脸并识别经过深度学习模型训练的人脸。

该程序最终将用作我正在创建的应用程序的一种身份验证(我意识到这将是不安全的),在该应用程序中,用户将首先使用其网络摄像头拍摄这些照片,然后这些图像将用于训练深度学习模型,以在用户打开程序进行身份验证时识别该用户。

我现在遇到的问题是,在本教程的步骤2中,我应该训练一个线性SVM模型来识别用户,但是我想我不能使用一个来训练机器学习模型。单班。那么我的问题是,由于人脸数据集中只有一个人脸,我该如何训练机器学习模型来识别应用程序的第一个用户?这意味着当我创建LabelEncoder()时将只有一个标签,并且将导致对SVC.fit()的调用失败,并显示以下信息:

ValueError: The number of classes has to be greater than one; got 1 class

这是我的train_model.py文件,其中包含相关的训练算法:

from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC

import common
import data_handler

def train_and_save(facial_embeddings_database: str= common.EMBEDDINGS_LOC, output_file: str = common.RECOGNITION_DATABASE_LOC) -> None:
    """
    Trains the database using the given facial embeddings database and outputs the results to file.
    :param facial_embeddings_database: The facial embedding database location.
    :param output_file: The file location for the output of the database.
    :return: None
    """
    database = data_handler.load_database(facial_embeddings_database)
    data_handler.write_database(output_file, train_model(database))


def train_model(facial_embeddings: dict) -> SVC:
    """
    Trains the model for the given database
    :param facial_embeddings_database: The location of the pickle database.
    :param output_file: File location where to output the pickle database of facial recognitions.
    :return:
    """
    label_encoder = LabelEncoder()
    X = []
    for user_id, encodings in facial_embeddings.items():
        X.extend([user_id for x in range(len(encodings))])

    # The facial_embeddings
    labels = label_encoder.fit_transform(X)

    recognizer = SVC(C=1.0, kernel="linear", probability=True)

    # TODO not too sure this line does what is intended.
    # THIS IS WHAT FAILS WHEN I ONLY HAVE ONE LABEL
    recognizer.fit(data_handler.get_encodings_in_database(facial_embeddings), labels)

    return recognizer

if __name__ == "__main__":
    train_and_save()

0 个答案:

没有答案