减少重复的Django模型代码

时间:2014-10-20 20:35:53

标签: django

我有一个项目,在其中我把外键值加密后存储在文本字段里。这样做的目的是把表分成两组:一组包含个人身份信息,另一组不包含。每个具有这种关系的模型都用 2 个字段和 2 个方法来实现:

class User(AbstractBaseUser):
    # ... fields

    # Encrypted link to an Identification row: the ciphertext of that row's
    # primary key and the AES key, both as returned by encrypt().
    encrypted_identification_id = models.TextField(null=True)
    encrypted_identification_key = models.TextField(null=True)

    def get_identification(self, private_key):
        """Decrypt the stored link and return the related Identification.

        Returns None when no AES key has been stored yet.
        """
        aes_key = self.encrypted_identification_key
        if aes_key:
            identification_pk = decrypt(
                private_key, self.encrypted_identification_id, aes_key)
            return Identification.objects.get(pk=identification_pk)
        return None

    def set_identification(self, identification):
        """Encrypt ``identification.pk``, store the result and save ``self``."""
        ciphertext = encrypt(str(identification.pk))
        self.encrypted_identification_id = ciphertext['encrypted_string']
        self.encrypted_identification_key = ciphertext['aes_key']
        self.save()


class Identification(models.Model):
    # ... fields

    # Encrypted link to a User row: the ciphertext of that row's primary key
    # and the AES key, both as returned by encrypt().
    encrypted_user_id = models.TextField(null=True)
    encrypted_user_key = models.TextField(null=True)

    def get_user(self, private_key):
        """Decrypt the stored link and return the related User.

        Returns None when no AES key has been stored yet.
        """
        aes_key = self.encrypted_user_key
        if aes_key:
            user_pk = decrypt(private_key, self.encrypted_user_id, aes_key)
            return User.objects.get(pk=user_pk)
        return None

    def set_user(self, user):
        """Encrypt ``user.pk``, store the result and save ``self``."""
        ciphertext = encrypt(str(user.pk))
        self.encrypted_user_id = ciphertext['encrypted_string']
        self.encrypted_user_key = ciphertext['aes_key']
        self.save()

这些字段和方法的代码完全相同,只是各自定义和引用的名称不同。我有很多模型都复制粘贴了这样的代码。减少这种重复的最佳方法是什么?

2 个答案:

答案 0(得分:0)

您可以把相关部分抽象成一个函数:

def _get_decrypted(id_field, key_field, model, obj, private_key):
    """Decrypt the primary key stored on *obj* and fetch the matching row.

    Args:
        id_field: name of the attribute on *obj* holding the encrypted id.
        key_field: name of the attribute on *obj* holding the AES key.
        model: model class whose ``objects.get(pk=...)`` performs the lookup.
        obj: the instance carrying the two encrypted attributes.
        private_key: passed through to ``decrypt``.

    Returns:
        The related *model* instance, or None when *obj* has no key stored.
    """
    aes_key = getattr(obj, key_field)
    if not aes_key:
        return None
    decrypted_id = decrypt(private_key, getattr(obj, id_field), aes_key)
    return model.objects.get(pk=decrypted_id)

然后在每个模型中使用该函数:

def get_identification(self, private_key):
    """Return the linked Identification, or None if no key is stored."""
    return _get_decrypted(
        id_field='encrypted_identification_id',
        key_field='encrypted_identification_key',
        model=Identification,
        obj=self,
        private_key=private_key,
    )

def get_user(self, private_key):
    """Return the linked User, or None if no key is stored."""
    return _get_decrypted(
        id_field='encrypted_user_id',
        key_field='encrypted_user_key',
        model=User,
        obj=self,
        private_key=private_key,
    )

set 版本也可以用同样的方式写成一个函数,只需把 getattr 换成 setattr。

您也可以将其转换为mixin类的方法。

答案 1(得分:0)

Ben 的回答(答案 0)减少了方法中的重复代码,但并没有消除重复的字段定义。

更好的做法是使用元类自动生成所需的字段和方法:

class EncryptedRelationModelBase(ModelBase):
    """Model metaclass that generates an encrypted relation to another model.

    For a related model named e.g. ``User`` it injects two nullable
    ``TextField``s, ``encrypted_user_id`` and ``encrypted_user_key``, plus
    ``get_user(private_key)`` and ``set_user(instance)`` methods. The
    configuration (``related_model``, or ``app_label`` + ``model_name``) is
    read from class attributes of the metaclass itself, so it is meant to be
    subclassed once per relation; ``EncryptedRelationModel`` does this.
    """

    # Either the related model class itself, or the app label / model name
    # used to resolve it with get_model() at the time a getter runs.
    related_model = None
    app_label = None
    model_name = None
    # Set to True once the attributes have been injected. The metaclass is
    # inherited by subclasses of the first (abstract) model it builds, and
    # those subclasses inherit the generated fields/methods from that base,
    # so they must not receive a second copy.
    added_encrypted_attrs = False

    def __new__(cls, name, bases, attrs):
        encrypted_attrs = cls.get_encrypted_attrs()
        attrs.update(encrypted_attrs)
        return super(EncryptedRelationModelBase, cls).__new__(cls, name, bases,
                                                              attrs)

    @classmethod
    def get_encrypted_attrs(cls):
        """Return the attributes to inject, or ``{}`` if already injected."""
        # This method should add attributes only once
        if cls.added_encrypted_attrs:
            return {}

        ret = cls._get_encrypted_attrs()

        cls.added_encrypted_attrs = True
        return ret

    @classmethod
    def _get_encrypted_attrs(cls):
        """Build the two fields and the getter/setter methods for the model."""
        # The names depend only on the metaclass configuration, so compute
        # them once here instead of on every getter/setter call.
        id_field = cls.get_id_field_name()
        key_field = cls.get_key_field_name()

        def get_related_instance(self, private_key):
            # Returns None until set_<name>() has stored an AES key.
            aes_key = getattr(self, key_field)
            if not aes_key:
                return None
            # Bare decrypt(), matching encrypt() below and the question's
            # code (the original mixed in an undefined ``shcrypto.`` prefix).
            related_instance_id = decrypt(private_key,
                                          getattr(self, id_field), aes_key)
            return cls.get_related_model().objects.get(pk=related_instance_id)

        def set_related_instance(self, instance):
            encrypted = encrypt(str(instance.pk))
            setattr(self, id_field, encrypted['encrypted_string'])
            setattr(self, key_field, encrypted['aes_key'])
            self.save()

        return {
            id_field: models.TextField(null=True, blank=True),
            key_field: models.TextField(null=True, blank=True),
            cls.get_get_method_name(): get_related_instance,
            cls.get_set_method_name(): set_related_instance,
        }

    @classmethod
    def get_related_model(cls):
        """Return the related model class, resolving it by name if needed."""
        if cls.related_model:
            return cls.related_model
        return get_model(cls.app_label, cls.model_name)

    @classmethod
    def format_related_model_name(cls):
        """Return the snake_case name used in generated field/method names.

        Raises:
            TypeError: if neither ``model_name`` nor ``related_model`` is set.
        """
        if cls.model_name:
            return cls.snake_case(cls.model_name)
        if cls.related_model is None:
            # Fail at class-creation time with a readable message instead of
            # an AttributeError on ``None.__name__``.
            raise TypeError('EncryptedRelationModel requires either '
                            'related_model or model_name')
        return cls.snake_case(cls.related_model.__name__)

    @staticmethod
    def snake_case(name):
        """Convert a CamelCase name to snake_case ('UserProfile' -> 'user_profile')."""
        name = re.sub('(.)([A-Z](?!s[A-Z])[a-z]+)', r'\1_\2', name)
        name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', name)
        name = name.lower()
        name = name.replace('__', '_')
        return name

    @classmethod
    def get_id_field_name(cls):
        """Name of the field holding the encrypted related id."""
        return 'encrypted_{}_id'.format(cls.format_related_model_name())

    @classmethod
    def get_key_field_name(cls):
        """Name of the field holding the AES key for the encrypted id."""
        return 'encrypted_{}_key'.format(cls.format_related_model_name())

    @classmethod
    def get_get_method_name(cls):
        """Name of the generated getter method."""
        return 'get_{}'.format(cls.format_related_model_name())

    @classmethod
    def get_set_method_name(cls):
        """Name of the generated setter method."""
        return 'set_{}'.format(cls.format_related_model_name())


def EncryptedRelationModel(related_model=None, app_label=None, model_name=None):
    """Return an abstract model base class carrying one encrypted relation.

    Pass either the related model class (``related_model``) or its
    ``app_label`` and ``model_name``, which are resolved with ``get_model``
    when the generated getter runs. Subclassing the returned class gives the
    ``encrypted_<name>_id`` / ``encrypted_<name>_key`` fields and the
    ``get_<name>`` / ``set_<name>`` methods built by
    ``EncryptedRelationModelBase``.
    """
    def create_metaclass(model, label, name):
        # Parameters use different names because a class body that assigns a
        # name (``related_model = related_model``) looks the right-hand side
        # up in module globals, not in the enclosing function.
        class NewModelBase(EncryptedRelationModelBase):
            related_model = model
            app_label = label
            model_name = name
        return NewModelBase

    class Meta:
        abstract = True

    metaclass = create_metaclass(related_model, app_label, model_name)
    # Call the metaclass directly instead of setting a ``__metaclass__``
    # class attribute: that attribute only works on Python 2 and is silently
    # ignored on Python 3, where no encrypted fields/methods would be added.
    # Django's ModelBase reads ``__module__`` from the class dict, so supply
    # the value a class statement in this module would have set.
    return metaclass('NewModel', (models.Model,),
                     {'__module__': __name__, 'Meta': Meta})


class Identification(EncryptedRelationModel(User)):
    """Identification with an encrypted relation to User.

    The generated base adds ``encrypted_user_id`` / ``encrypted_user_key``
    fields and ``get_user(private_key)`` / ``set_user(user)`` methods.
    """