在DRF ModelSerializer中优化查询数量

时间:2018-07-06 12:05:37

标签: django django-rest-framework

在Django Rest Framework的序列化程序中,可以向序列化对象添加比原始模型更多的数据。 这对于在服务器端计算统计信息以及在响应API调用时添加此额外信息很有用。

据我了解,添加额外的数据是使用SerializerMethodField完成的,其中每个字段都是通过get_...函数实现的。 但是,如果有多个SerializerMethodField,则每个字段都可以分别查询Model /数据库,以获取实质上相同的数据。

是否可以查询一次数据库,将列表/结果存储为ModelSerializer对象的数据成员,并在许多函数中使用queryset的结果?

这是一个非常简单的示例,仅供说明:

############## Model

class Employee(Model):
    SALARY_TYPE_CHOICES = (('HR', 'Hourly Rate'), ('YR', 'Annual Salary'))
    salary_type = CharField(max_length=2, choices=SALARY_TYPE_CHOICES, blank=False)
    salary = PositiveIntegerField(blank=True, null=True, default=0)
    company = ForeignKey(Company, related_name='employees')

class Company(Model):
    name = CharField(verbose_name='company name', max_length=100)


############## View

class CompanyView(RetrieveAPIView):
    queryset = Company.objects.all()
    lookup_field='id'
    serializer_class = CompanySerialiser

class CompanyListView(ListAPIView):
    queryset = Company.objects.all()
    serializer_class = CompanySerialiser


############## Serializer

class CompanySerialiser(ModelSerializer):
    number_employees = SerializerMethodField()
    total_salaries_estimate = SerializerMethodField()
    class Meta:
        model = Company
        fields = ['id', 'name',
                  'number_employees',
                  'total_salaries_estimate',
                 ]
    def get_number_employees(self, obj):
        return obj.employees.count()
    def get_total_salaries_estimate(self, obj):
        employee_list = obj.employees.all()
        salaries_estimate = 0
        HOURS_PER_YEAR = 8*200 # 8hrs/day, 200days/year
        for empl in employee_list:
            if empl.salary_type == 'YR':
                salaries_estimate += empl.salary
            elif empl.salary_type == 'HR':
                salaries_estimate += empl.salary * HOURS_PER_YEAR
        return salaries_estimate

串行器可以优化为:

  • 使用对象数据成员存储查询集中的结果,
  • 只检索一次查询集,
  • 将SerialsetMethodFields中提供的所有其他信息重新使用查询集的结果。

示例:

class CompanySerialiser(ModelSerializer):
    def __init__(self, *args, **kwargs):
        super(CompanySerialiser, self).__init__(*args, **kwargs)
        self.employee_list = None

    number_employees = SerializerMethodField()
    total_salaries_estimate = SerializerMethodField()
    class Meta:
        model = Company
        fields = ['id', 'name',
                  'number_employees',
                  'total_salaries_estimate',
                 ]
    def _populate_employee_list(self, obj):
        if not self.employee_list: # Query the database only once.
            self.employee_list = obj.employees.all()
    def get_number_employees(self, obj):
        self._populate_employee_list(obj)
        return len(self.employee_list)
    def get_total_salaries_estimate(self, obj):
        self._populate_employee_list(obj)
        salaries_estimate = 0
        HOURS_PER_YEAR = 8*200 # 8hrs/day, 200days/year
        for empl in self.employee_list:
            if empl.salary_type == 'YR':
                salaries_estimate += empl.salary
            elif empl.salary_type == 'HR':
                salaries_estimate += empl.salary * HOURS_PER_YEAR
        return salaries_estimate

这适用于单个检索CompanyView。并且,实际上将一个查询/上下文切换/往返保存到数据库;我已经消除了“计数”查询。

但是,对于列表视图CompanyListView,它不起作用,因为似乎序列化程序对象只创建了一次,并为每个公司重用。因此,只有第一公司的雇员列表存储在对象“ self.employee_list”数据成员中,因此,所有其他公司错误地从第一公司获得了数据。

是否存在针对此类问题的最佳实践解决方案?或者我只是错误地使用ListAPIView,如果是这样,还有替代方法吗?

2 个答案:

答案 0 :(得分:1)

我认为,如果您可以将查询集传递给已经提取了数据的CompanySerialiser,则可以解决此问题。

您可以进行以下更改

class CompanyListView(ListAPIView):
    queryset = Company.objects.all().prefetch_related('employee_set')
    serializer_class = CompanySerialiser`

使用len函数代替count,因为count会再次执行查询。

class CompanySerialiser(ModelSerializer):
    number_employees = SerializerMethodField()
    total_salaries_estimate = SerializerMethodField()
    class Meta:
        model = Company
        fields = ['id', 'name',
                  'number_employees',
                  'total_salaries_estimate',
                 ]
    def get_number_employees(self, obj):
        return len(obj.employees.all())
    def get_total_salaries_estimate(self, obj):
        employee_list = obj.employees.all()
        salaries_estimate = 0
        HOURS_PER_YEAR = 8*200 # 8hrs/day, 200days/year
        for empl in employee_list:
            if empl.salary_type == 'YR':
                salaries_estimate += empl.salary
            elif empl.salary_type == 'HR':
                salaries_estimate += empl.salary * HOURS_PER_YEAR
        return salaries_estimate

由于已预取数据,因此序列化程序将不会对all进行任何其他查询。但是请确保您没有进行任何筛选,因为在这种情况下将执行另一个查询。

答案 1 :(得分:0)

如@Ritesh Agrawal所述,您只需要预取数据即可。但是,我建议直接在数据库内部进行聚合,而不要使用Python:

class CompanySerializer(ModelSerializer):
    number_employees = IntegerField()
    total_salaries_estimate = FloatField()

    class Meta:
    model = Company
    fields = ['id', 'name',
              'number_employees',
              'total_salaries_estimate', ...
             ]

class CompanyListView(ListAPIView):
    queryset = Company.objects.annotate(
       number_employees=Count('employees'),
       total_salaries_estimate=Sum(
           Case(
               When(employees__salary_type=Value('HR'),
                    then=F('employees_salary') * Value(8 * 200)
               ),
               default=F('employees__salary'),
               output_field=IntegerField() #optional a priori, because you only manipulate integers
           )
        )
    )
    serializer_class = CompanySerializer

注意:

  • 我尚未测试此代码,但是我在自己的项目中使用了相同的语法。如果遇到错误(例如“无法确定输出类型”或类似错误),请尝试将F('employees_salary') * Value(8 * 200)包装在ExpressionWrapper(..., output_field=IntegerField())内。
  • 使用聚合,之后可以在查询集上应用过滤器。但是,如果您要预取相关的Employee,则无法再过滤相关的对象(如上一个答案中所述)。但是,如果您已经知道需要按小时收费的员工列表,则可以执行.prefetch_related(Prefetch('employees', queryset=Employee.object.filter(salary_type='HR'), to_attr="hourly_rate_employees"))

相关文档: Query optimization Aggregation

希望这会对您有所帮助;)