无法在 Django REST framework 中 PATCH 嵌套值

时间:2018-01-17 17:44:59

标签: python django unit-testing django-rest-framework

我正在尝试使用django-rest-framework开发REST API来更新django模型。

我想用下面的单元测试来验证它:

from rest_framework.test import APITestCase
class PatchInvestmentTest(APITestCase):
    """Exercise PATCH on the investment detail endpoint with nested sponsorships."""

    def test_repartition(self):
        """PATCH two sponsorships onto one investment and expect both to be stored."""
        investment = Investment.objects.create()
        # Sponsors are InvestmentSponsor rows, so they are created through that
        # model's own manager; ``create`` only takes keyword arguments.
        sponsor1 = InvestmentSponsor.objects.create(name='A')
        sponsor2 = InvestmentSponsor.objects.create(name='B')

        url = reverse('investments:investments-detail', args=[investment.id])
        data = {
            'sponsorships': [
                {'sponsor': sponsor1.id, 'percentage': 80},
                {'sponsor': sponsor2.id, 'percentage': 10},
            ]
        }

        print("> data", data)

        # NOTE(review): no explicit ``format`` is given, so the test client's
        # default request encoding is used. Presumably that encoding cannot
        # carry a nested list of dicts; passing format='json' is the likely
        # remedy -- confirm against the DRF testing documentation.
        response = self.client.patch(url, data=data)

        self.assertEqual(response.status_code, status.HTTP_200_OK)

        self.assertEqual(1, Investment.objects.count())
        investment = Investment.objects.all()[0]
        # It fails below: no sponsorships are created
        self.assertEqual(len(investment.sponsorships()), 2)

模型可以概括如下:
class Investment(models.Model):
    """An investment; its sponsorships are reached through the reverse relation."""
    # ... a few fields

    def sponsorships(self):
        """Return this investment's sponsorships, sorted by sponsor ordering then sponsor name."""
        related = self.investmentsponsorship_set.all()
        sort_keys = ('sponsor__ordering', 'sponsor__name')
        return related.order_by(*sort_keys)


class InvestmentSponsor(models.Model):
    """A named sponsor with a manual sort position."""
    # Display name, at most 200 characters. The label goes through `_`
    # (presumably a gettext alias imported elsewhere -- confirm).
    name = models.CharField(max_length=200, verbose_name=_('name'))
    # Manual sort position; defaults to 0.
    ordering = models.IntegerField(default=0)

    class Meta:
        # Default ordering: by the `ordering` field first, then by `name`.
        ordering = ('ordering', 'name', )


class InvestmentSponsorship(models.Model):
    """Link between one sponsor and one investment, carrying a percentage."""
    # ``on_delete`` is spelled out because Django 2.0+ requires it; CASCADE
    # is the behaviour older Django versions applied implicitly.
    sponsor = models.ForeignKey(InvestmentSponsor, on_delete=models.CASCADE)
    investment = models.ForeignKey(Investment, on_delete=models.CASCADE)
    # Up to 5 digits in total, 2 of them after the decimal point.
    percentage = models.DecimalField(max_digits=5, decimal_places=2)

api正在使用rest-framework基类

class InvestmentViewSet(viewsets.ModelViewSet):
    """Viewset for Investment that selects its serializer from the current action."""
    model = Investment

    def get_serializer_class(self):
        """Return the partial-update serializer for ``partial_update``, otherwise the default one."""
        if self.action == 'partial_update':
            return PartialUpdateInvestmentSerializer
        return InvestmentSerializer

    def perform_update(self, serializer):
        """Persist the update by saving the serializer."""
        serializer.save()

然后我期望在序列化器中获取并处理 sponsorships 数据:

class InvestmentSponsorshipSerializer(serializers.ModelSerializer):
    """Serializer for one InvestmentSponsorship, exposing only sponsor and percentage."""

    class Meta:
        model = models.InvestmentSponsorship
        # `investment` is not listed, so it is not part of the payload.
        fields = ('sponsor', 'percentage', )


class PartialUpdateInvestmentSerializer(serializers.ModelSerializer):
    """Serializer used for PATCH on an Investment, including its nested sponsorships."""
    sponsorships = InvestmentSponsorshipSerializer(many=True)

    class Meta:
        model = models.Investment
        fields = (
            'id', '... others', 'sponsorships',
        )

    def validate_sponsorships(self, value):
        """Field-level hook for ``sponsorships``: trace the value and pass it through unchanged."""
        print("validate_sponsorships", value)
        return value

    def update(self, instance, validated_data):
        """Update only the fields present in ``validated_data``.

        Non-nested fields are delegated to ``ModelSerializer.update``. When
        ``sponsorships`` is present, the investment's existing sponsorship
        rows are replaced by the submitted list (an empty list therefore
        removes them all); when the key is absent they are left untouched.

        NOTE(review): replace-all semantics are an assumption -- confirm
        that a PATCH is meant to carry the complete sponsorship list.

        Returns the updated ``instance``.
        """
        # Work on a copy so the caller's validated_data is not mutated by pop().
        data = validated_data.copy()

        # Debug trace kept from the original implementation.
        print("*** DATA", validated_data)

        # The nested list must be removed before delegating: the stock
        # ModelSerializer.update does not handle nested writes.
        sponsorships = data.pop('sponsorships', None)

        instance = super(PartialUpdateInvestmentSerializer, self).update(instance, data)

        if sponsorships is not None:
            instance.investmentsponsorship_set.all().delete()
            for item in sponsorships:
                models.InvestmentSponsorship.objects.create(investment=instance, **item)

        return instance

问题是我在序列化器中收到的数据是空的:

 > data {'sponsorships': [{'sponsor': 1, 'percentage': 80}, {'sponsor': 2, 'percentage': 10}]}
 validate_sponsorships []
 *** DATA {'sponsorships': []}

这似乎只在单元测试时才会发生。通过 django-rest-framework 的管理界面操作时似乎可以正常工作。

我已经尝试调查为什么在 update 中收到的 validated_data 里没有数据,但尚未成功。

有什么想法吗?

2 个答案:

答案 0 :(得分:1)

您应该在调用 patch 时添加 format 参数:

        # Send the payload JSON-encoded instead of relying on the client's default format.
        response = self.client.patch(url, data=data, format='json')

我认为默认multipart格式不支持嵌套。

答案 1 :(得分:0)

我通过使用常规的 Django 单元测试类找到了解决方案。虽然有点麻烦,但它可以工作:

from django.test import TestCase

class PatchInvestmentTest(TestCase):
    """PATCH nested sponsorships using Django's plain test client and a JSON body."""

    def test_repartition(self):
        """Send two sponsorships as JSON and check that both end up on the investment."""
        investment = Investment.objects.create()

        sponsor_a = mommy.make(InvestmentSponsor, name='A', ordering=3)
        sponsor_b = mommy.make(InvestmentSponsor, name='B', ordering=2)

        payload = {
            "sponsorships": [
                {"sponsor": sponsor_b.id, "percentage": 80},
                {"sponsor": sponsor_a.id, "percentage": 10},
            ]
        }
        detail_url = reverse('investments:investments-detail', args=[investment.id])

        # The body is serialized by hand and labelled as JSON.
        response = self.client.patch(
            detail_url,
            data=json.dumps(payload),
            content_type="application/json",
        )

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(1, Investment.objects.count())

        investment = Investment.objects.all()[0]
        self.assertEqual(len(investment.sponsorships()), 2)