如何加快数据库的写入速度?

时间:2019-03-05 11:16:51

标签: python django database sqlite

我有一个功能,可以在目录中搜索json文件,解析文件并将数据写入数据库。我的问题是写数据库,因为它大约需要30分钟。知道如何加快在数据库中的写入速度吗?我要解析的文件很少,但是解析文件不是问题,大约需要3分钟。目前,我正在使用sqlite,但将来我会将其更改为PostgreSQL。

这是我的功能:

def create_database():
    with transaction.atomic():
        directory = os.fsencode('data/web_files/unzip')
        for file in os.listdir(directory):
            filename = os.fsdecode(file)

            with open('data/web_files/unzip/{}'.format(filename.strip()), encoding="utf8") as f:
                data = json.load(f)
                cve_items = data['CVE_Items']
                for i in range(len(cve_items)):
                    database_object = DataNist()

                    try:
                        impact = cve_items[i]['impact']['baseMetricV2']
                        database_object.severity = impact['severity']
                        database_object.exp_score = impact['exploitabilityScore']
                        database_object.impact_score = impact['impactScore']
                        database_object.cvss_score = impact['cvssV2']['baseScore']

                    except KeyError:
                        database_object.severity = ''
                        database_object.exp_score = ''
                        database_object.impact_score = ''
                        database_object.cvss_score = ''

                    for vendor_data in cve_items[i]['cve']['affects']['vendor']['vendor_data']:
                            database_object.vendor_name = vendor_data['vendor_name']

                            for description_data in cve_items[i]['cve']['description']['description_data']:
                                database_object.description = description_data['value']

                            for product_data in vendor_data['product']['product_data']:
                                database_object.product_name = product_data['product_name']
                                database_object.save()

                                for version_data in product_data['version']['version_data']:
                                    if version_data['version_value'] != '-':
                                    database_object.versions_set.create(version=version_data['version_value'])

我的模型。py:

class DataNist(models.Model):
    vendor_name = models.CharField(max_length=100)
    product_name = models.CharField(max_length=100)
    description = models.TextField()
    date = models.DateTimeField(default=timezone.now)
    severity = models.CharField(max_length=10)
    exp_score = models.IntegerField()
    impact_score = models.IntegerField()
    cvss_score = models.IntegerField()


    def __str__(self):
        return self.vendor_name + "-" + self.product_name


class Versions(models.Model):
    data = models.ForeignKey(DataNist, on_delete=models.CASCADE)
    version = models.CharField(max_length=50)

    def __str__(self):
        return self.version

如果您能给我任何建议,我将如何改善我的代码,我将不胜感激。

1 个答案:

答案 0 :(得分:1)

好吧,鉴于数据的结构,类似这样的方法可能对您有用。

这是.objects.bulk_create()调用之外的独立代码;如代码中所述,定义的两个类实际上是Django应用程序中的模型。 (顺便说一句,您可能也希望将CVE ID保存为唯一字段。)

您的原始代码有一个错误的假设,即受影响的版本数据中的每个“叶子条目”都将具有相同的供应商,这可能不是正确的。因此,此处的模型结构具有一个单独的产品版本模型,其中包含供应商,产品和版本字段。 (如果您想稍微优化一些东西,甚至可以在AffectedProductVersion之间对DataNist进行重复数据删除(顺便说一句,它并不是模型的完美名称)。

当然,就像您在原始代码中所做的那样,导入应该在事务(transaction.atomic())中运行。

希望这会有所帮助。

import json
import os
import types


class DataNist(types.SimpleNamespace):  # this would actually be a model
    severity = ""
    exp_score = ""
    impact_score = ""
    cvss_score = ""

    def save(self):
        pass


class AffectedProductVersion(types.SimpleNamespace):  # this too
    # (foreign key to DataNist here)
    vendor_name = ""
    product_name = ""
    version_value = ""


def import_item(item):
    database_object = DataNist()
    try:
        impact = item["impact"]["baseMetricV2"]
    except KeyError:  # no impact object available
        pass
    else:
        database_object.severity = impact.get("severity", "")
        database_object.exp_score = impact.get("exploitabilityScore", "")
        database_object.impact_score = impact.get("impactScore", "")
        if "cvssV2" in impact:
            database_object.cvss_score = impact["cvssV2"]["baseScore"]

    for description_data in item["cve"]["description"]["description_data"]:
        database_object.description = description_data["value"]
        break  # only grab the first description

    database_object.save()  # save the base object

    affected_versions = []
    for vendor_data in item["cve"]["affects"]["vendor"]["vendor_data"]:
        for product_data in vendor_data["product"]["product_data"]:
            for version_data in product_data["version"]["version_data"]:
                affected_versions.append(
                    AffectedProductVersion(
                        data_nist=database_object,
                        vendor_name=vendor_data["vendor_name"],
                        product_name=product_data["product_name"],
                        version_name=version_data["version_value"],
                    )
                )

    AffectedProductVersion.objects.bulk_create(
        affected_versions
    )  # save all the version information

    return database_object  # in case the caller needs it


with open("nvdcve-1.0-2019.json") as infp:
    data = json.load(infp)
    for item in data["CVE_Items"]:
        import_item(item)