SQLAlchemy:指向同一个表的多个ForeignKeys,一些是可选的,启用了多态继承

时间:2014-10-08 20:45:33

标签: python sqlalchemy polymorphism

我为一个使用 SQLAlchemy 存储数据的应用程序创建了一些基类。其中一个基类(Content)是多态的,包含 id、title、description、时间戳等通用字段。它的子类应当把各自的额外字段存储在单独的表中。我写了一个独立的代码示例来更好地说明这个概念。该示例包含 Base 类、若干子类,以及用于创建 SQLite 数据库的引导代码。运行示例最简单的方法是:创建一个 virtualenv,在其中安装 SQLAlchemy,把代码粘贴到一个文件里,然后用该 virtualenv 的解释器运行。示例中有一部分出问题的代码被注释掉了;只要这些代码保持注释状态,示例就应该能无错运行(至少在我这里是这样)。

一旦取消这些代码的注释,示例就会失败,而我不太确定该如何解决这个问题——非常欢迎任何帮助!

示例概述:

  • 它有一些基类(Base 和 Content)。
  • 它有一个继承 Content 的 Task 类。
  • Task 可以有子任务,并且应按 position 保持排序。
  • 它有一个继承 Content 的 Project 类(已注释)。
  • Project 有一个 due_date 和若干里程碑(milestones,即一个 Task 列表)。
  • 它有一个继承 Content 的 Worklist 类(已注释)。
  • Worklist 属于某个"员工"(employee),并包含若干任务。

我想要实现的是:Task 可以作为一个独立的类使用,同时其他类(例如 Project 和 Worklist)也可以拥有任务。我不想最终得到好几张任务表或关联表,而是希望借助 Content 来实现这个概念,以这种"通用"的方式附加任务。

示例代码:

from datetime import datetime
from datetime import timedelta
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Boolean
from sqlalchemy import String
from sqlalchemy import DateTime
from sqlalchemy import Date
from sqlalchemy import Unicode
from sqlalchemy import UnicodeText
from sqlalchemy import ForeignKey
from sqlalchemy import MetaData
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.orm import Session
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import relationship
from sqlalchemy.orm import backref
from sqlalchemy.util import classproperty


class Base(object):
    """
    Mixin with a derived table name and session-backed persistence helpers.
    """

    @declared_attr
    def __tablename__(cls):
        # The table is named after the lower-cased class name.
        table_name = cls.__name__.lower()
        return table_name

    @property
    def columns(self):
        """Return the keys of the columns known to this object's mapper."""
        mapper = self.__mapper__
        return mapper.columns.keys()

    def add(self, **data):
        """Apply ``data`` through ``update``, then add to the session and flush."""
        self.update(**data)
        db_session.add(self)
        db_session.flush()

    def delete(self):
        """Mark this object for deletion in the session and flush."""
        db_session.delete(self)
        db_session.flush()

    def update(self, **data):
        """
        Copy values from ``data`` onto the matching column attributes.

        Keys that are not column names are ignored, and so are ``None``
        values.
        """
        for column_name in self.columns:
            new_value = data.get(column_name)
            if new_value is None:
                continue
            setattr(self, column_name, new_value)


# Engine for the SQLite file test.db; echo=True enables SQLAlchemy's statement logging.
engine = create_engine('sqlite:///test.db', echo=True)
metadata = MetaData()
# Scoped session bound to the engine; the add()/delete() helpers above go through it.
db_session = scoped_session(sessionmaker(bind=engine))

# Rebind the name Base to the declarative base built on top of the mixin class above.
Base = declarative_base(cls=Base)
Base.metadata = metadata
# Expose a `query` attribute on the declarative base, backed by db_session.
Base.query = db_session.query_property()


class Content(Base):
    """
    Base class for all content. Includes basic features such as
    ownership and timestamps for modification and creation.

    The mapping is polymorphic on the ``type`` column; every class uses its
    lower-cased class name as polymorphic identity.
    """

    @classproperty
    def __mapper_args__(cls):
        # Evaluated per class so each subclass gets its own identity.
        return dict(
            polymorphic_on='type',
            polymorphic_identity=cls.__name__.lower(),
            with_polymorphic='*')

    id = Column(Integer(), primary_key=True)
    # Discriminator column named by polymorphic_on above.
    type = Column(String(30), nullable=False)
    owner = Column(Unicode(128))
    title = Column(Unicode(128))
    description = Column(UnicodeText())
    creation_date = Column(DateTime(), nullable=False, default=datetime.utcnow)
    modification_date = Column(DateTime(), nullable=False, default=datetime.utcnow)

    def __init__(self, **data):
        """Populate the new object from ``data`` and persist it via ``add``."""
        self.add(**data)

    def update(self, touch=True, **data):
        """
        Iterate over all columns and set values from data.

        :param touch: when true and ``data`` carries no ``modification_date``
            key, ``modification_date`` is set to the current UTC time.
        :param data: column values, forwarded to the base ``update``.
        :return: None
        """
        super(Content, self).update(**data)
        if touch and 'modification_date' not in data:
            self.modification_date = datetime.utcnow()

    def __eq__(self, other):
        """Content objects compare equal when their ``id`` values match."""
        return isinstance(other, Content) and self.id == other.id

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone leaves instances unhashable on Python 3;
        # hash on the same key that equality compares.
        return hash(self.id)


def get_content(id):
    """Look up the Content object identified by ``id`` via ``Content.query``."""
    content_query = Content.query
    return content_query.get(id)


class Task(Content):
    """
    Content subtype with a completion flag, a position and nested subtasks.
    """

    # The primary key is also the foreign key to the parent Content row.
    id = Column(Integer, ForeignKey(Content.id), primary_key=True)
    # Disabled along with the commented-out Project/Worklist classes further down,
    # which refer to it as 'Task.content_id':
    # content_id = Column(Integer, ForeignKey(Content.id), nullable=True)

    done = Column(Boolean, default=False)
    # Slot of this task inside a `tasks` list; maintained by ordering_list below.
    position = Column(Integer, default=0)
    # Self-referential foreign key to the parent task; nullable, so a task may have no parent.
    parent_id = Column(Integer, ForeignKey('task.id'), nullable=True)

    # Subtasks, ordered by `position`; the backref exposes `parent` on each subtask.
    # foreign_keys names parent_id explicitly as the column that links the two rows.
    tasks = relationship(
        'Task',
        cascade='all, delete, delete-orphan',
        backref=backref('parent', remote_side=id),
        foreign_keys='Task.parent_id',
        order_by=position,
        collection_class=ordering_list('position', reorder_on_append=True)
    )

def default_due_date():
    """Return the current naive UTC datetime shifted 60 days into the future."""
    sixty_days = timedelta(days=60)
    return datetime.utcnow() + sixty_days


# class Project(Content):
#
#     id = Column(Integer, ForeignKey(Content.id), primary_key=True)
#     due_date = Column(Date, default=default_due_date)
#
#     milestones = relationship(
#         'Task',
#         cascade='all, delete, delete-orphan',
#         backref=backref('content_parent', remote_side=id),
#         foreign_keys='Task.content_id',
#         collection_class=ordering_list('position', reorder_on_append=True)
#     )
#
#
# class Worklist(Content):
#
#     id = Column(Integer, ForeignKey(Content.id), primary_key=True)
#     employee = Column(Unicode(128), nullable=False)
#
#     tasks = relationship(
#         'Task',
#         cascade='all, delete, delete-orphan',
#         backref=backref('content_parent', remote_side=id),
#         foreign_keys='Task.content_id',
#         collection_class=ordering_list('position', reorder_on_append=True)
#     )


def main():
    """
    Create the schema and check Task creation, subtask ordering and
    subtask reordering through assertions.

    The Project and Worklist scenarios at the end are commented out, like
    the classes they rely on.
    """
    # Reset the scoped-session registry, then bind session and metadata to the engine.
    db_session.registry.clear()
    db_session.configure(bind=engine)
    metadata.bind = engine
    metadata.create_all(engine)

    # Test basic operation
    # Constructing a Task already adds and flushes it (see Content.__init__), so task.id is usable.
    task = Task(title=u'Buy milk')
    task = get_content(task.id)

    # assert Content attributes inherited
    assert task.title == u'Buy milk'
    assert task.done == False

    # add subtasks
    task.tasks = [
        Task(title=u'Remember to check expiration date'),
        Task(title=u'Check bottle is not leaking')
    ]

    # assert that subtasks are added and correctly ordered
    task = get_content(task.id)
    assert len(task.tasks) == 2
    assert [(x.position, x.title) for x in task.tasks] == \
           [(0, u'Remember to check expiration date'),
            (1, u'Check bottle is not leaking')]

    # reorder subtasks: move the second subtask to the front
    task.tasks.insert(0, task.tasks.pop(1))
    task = get_content(task.id)
    assert len(task.tasks) == 2
    assert [(x.position, x.title) for x in task.tasks] == \
           [(0, u'Check bottle is not leaking'),
            (1, u'Remember to check expiration date')]

    # # Test Project implementation
    # project = Project(title=u'My project')
    # milestone1 = Task(title=u'Milestone #1', description=u'First milestone')
    # milestone2 = Task(title=u'Milestone #2', description=u'Second milestone')
    # milestone1.tasks = [Task(title=u'Subtask for Milestone #1'), ]
    # milestone2.tasks = [Task(title=u'Subtask #1 for Milestone #2'),
    #                     Task(title=u'Subtask #2 for Milestone #2')]
    # project.milestones = [milestone1, milestone2]
    # project = get_content(project.id)
    # assert project.title == u'My project'
    # assert len(project.milestones) == 2
    # assert [(x.position, x.title) for x in project.milestones] == \
    #        [(0, u'Milestone #1'), (1, u'Milestone #2')]
    # assert len(Task.query.all()) == 8
    # assert isinstance(milestone1.content_parent, Project) == True
    #
    # # Test Worklist implementation
    # worklist = Worklist(title=u'My worklist', employee=u'Torkel Lyng')
    # worklist.tasks = [
    #     Task(title=u'Ask stackoverflow for help'),
    #     Task(title=u'Learn SQLAlchemy')
    # ]
    # worklist = get_content(worklist.id)
    # assert worklist.title == u'My worklist'
    # assert worklist.employee == u'Torkel Lyng'
    # assert len(worklist.tasks) == 2
    # assert len(Task.query.all()) == 10
    # assert isinstance(worklist.tasks[0].content_parent, Worklist) == True


# Run the example only when this file is executed as a script.
if __name__ == '__main__':
    main()

抱歉示例这么长,我是想提供一个可以独立运行的例子。任何帮助、对设计的评论或建议都会非常有用。

1 个答案:

答案 0 :(得分:0)

我稍微重构了一下这个例子,让它基本可以工作了。我没有在 Task 上定义额外的 ForeignKey(content_id),而是把它作为 container_id 添加到了 Content 类中:

from datetime import datetime
from datetime import timedelta
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import Boolean
from sqlalchemy import String
from sqlalchemy import DateTime
from sqlalchemy import Date
from sqlalchemy import Unicode
from sqlalchemy import UnicodeText
from sqlalchemy import ForeignKey
from sqlalchemy import MetaData
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.orderinglist import ordering_list
from sqlalchemy.orm import Session
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import relationship
from sqlalchemy.orm import backref
from sqlalchemy.util import classproperty


class Base(object):
    """
    Mixin with a derived table name and session-backed persistence helpers.
    """

    @declared_attr
    def __tablename__(cls):
        # Lower-cased class name doubles as the table name.
        lowered = cls.__name__.lower()
        return lowered

    @property
    def columns(self):
        """Return the keys of the columns known to this object's mapper."""
        mapped_columns = self.__mapper__.columns
        return mapped_columns.keys()

    def add(self, **data):
        """Apply ``data`` through ``update``, then add to the session and flush."""
        self.update(**data)
        db_session.add(self)
        db_session.flush()

    def delete(self):
        """Mark this object for deletion in the session and flush."""
        db_session.delete(self)
        db_session.flush()

    def update(self, **data):
        """
        Copy values from ``data`` onto the matching column attributes.

        Keys that are not column names are ignored, and so are ``None``
        values.
        """
        for column_name in self.columns:
            candidate = data.get(column_name)
            if candidate is None:
                continue
            setattr(self, column_name, candidate)


# Engine for the SQLite file test.db; echo=True enables SQLAlchemy's statement logging.
engine = create_engine('sqlite:///test.db', echo=True)
metadata = MetaData()
# Scoped session bound to the engine; the add()/delete() helpers above go through it.
db_session = scoped_session(sessionmaker(bind=engine))

# Rebind the name Base to the declarative base built on top of the mixin class above.
Base = declarative_base(cls=Base)
Base.metadata = metadata
# Expose a `query` attribute on the declarative base, backed by db_session.
Base.query = db_session.query_property()


class Content(Base):
    """
    Base class for all content. Includes basic features such as
    ownership and timestamps for modification and creation.

    The mapping is polymorphic on the ``type`` column; every class uses its
    lower-cased class name as polymorphic identity. A Content row may also
    point at another Content row through ``container_id``.
    """

    @classproperty
    def __mapper_args__(cls):
        # Evaluated per class so each subclass gets its own identity.
        return dict(
            polymorphic_on='type',
            polymorphic_identity=cls.__name__.lower(),
            with_polymorphic='*')

    id = Column(Integer(), primary_key=True)
    # Optional self-referential foreign key to the containing Content row.
    container_id = Column(Integer(), ForeignKey('content.id'), nullable=True)
    # container = relationship('Content', foreign_keys=[container_id], uselist=False)
    # NOTE(review): a self-referential many-to-one usually also needs
    # remote_side=[id] -- confirm before re-enabling the relationship above.

    # Discriminator column named by polymorphic_on above.
    type = Column(String(30), nullable=False)
    owner = Column(Unicode(128))
    title = Column(Unicode(128))
    description = Column(UnicodeText())
    creation_date = Column(DateTime(), nullable=False, default=datetime.utcnow)
    modification_date = Column(DateTime(), nullable=False, default=datetime.utcnow)

    def __init__(self, **data):
        """Populate the new object from ``data`` and persist it via ``add``."""
        self.add(**data)

    @property
    def container(self):
        """Return the containing Content object, or None when there is none."""
        # Test against None explicitly so a falsy id is still looked up.
        if self.container_id is None:
            return None
        return get_content(self.container_id)

    def update(self, touch=True, **data):
        """
        Iterate over all columns and set values from data.

        :param touch: when true and ``data`` carries no ``modification_date``
            key, ``modification_date`` is set to the current UTC time.
        :param data: column values, forwarded to the base ``update``.
        :return: None
        """
        super(Content, self).update(**data)
        if touch and 'modification_date' not in data:
            self.modification_date = datetime.utcnow()

    def __eq__(self, other):
        """Content objects compare equal when their ``id`` values match."""
        return isinstance(other, Content) and self.id == other.id

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone leaves instances unhashable on Python 3;
        # hash on the same key that equality compares.
        return hash(self.id)

    def __repr__(self):
        return '<{0} "{1}">'.format(self.__class__.__name__, self.title)


def get_content(id):
    """Look up the Content object identified by ``id`` via ``Content.query``."""
    lookup = Content.query
    return lookup.get(id)


class Task(Content):
    """
    Content subtype with a completion flag, a position and nested subtasks.
    """

    # The primary key is also the foreign key to the parent Content row.
    id = Column(Integer, ForeignKey(Content.id), primary_key=True)

    done = Column(Boolean, default=False)
    # Slot of this task inside an ordered list; maintained by ordering_list below.
    position = Column(Integer, default=0)
    # Self-referential foreign key to the parent task; nullable, so a task may have no parent.
    parent_id = Column(Integer, ForeignKey('task.id'), nullable=True)

    # Subtasks, ordered by `position`; the backref exposes `parent` on each subtask.
    # foreign_keys names parent_id explicitly as the column that links the two rows.
    tasks = relationship(
        'Task',
        cascade='all, delete, delete-orphan',
        backref=backref('parent', remote_side=id),
        foreign_keys='Task.parent_id',
        order_by=position,
        collection_class=ordering_list('position', reorder_on_append=True)
    )

def default_due_date():
    """Return the current naive UTC datetime shifted 60 days into the future."""
    offset = timedelta(days=60)
    return datetime.utcnow() + offset


class Project(Content):
    """
    Content subtype with a due date and an ordered list of milestone tasks.
    """

    # The primary key is also the foreign key to the parent Content row.
    id = Column(Integer, ForeignKey(Content.id), primary_key=True)
    due_date = Column(Date, default=default_due_date)

    # Milestones are Task objects linked through the inherited container_id
    # column. order_by mirrors Task.tasks: ordering_list only maintains
    # `position` on in-memory changes; it does not sort rows when the
    # collection is loaded from the database.
    milestones = relationship(
        'Task',
        cascade='all, delete, delete-orphan',
        foreign_keys='Task.container_id',
        order_by='Task.position',
        collection_class=ordering_list('position', reorder_on_append=True)
    )


class Worklist(Content):
    """
    Content subtype with a mandatory employee name and an ordered task list.
    """

    # The primary key is also the foreign key to the parent Content row.
    id = Column(Integer, ForeignKey(Content.id), primary_key=True)
    employee = Column(Unicode(128), nullable=False)

    # Tasks are linked through the inherited container_id column.
    # order_by mirrors Task.tasks: ordering_list only maintains `position`
    # on in-memory changes; it does not sort rows when the collection is
    # loaded from the database.
    tasks = relationship(
        'Task',
        cascade='all, delete, delete-orphan',
        foreign_keys='Task.container_id',
        order_by='Task.position',
        collection_class=ordering_list('position', reorder_on_append=True)
    )


def main():
    """
    Create the schema and check the Task, Project and Worklist scenarios
    through assertions, then delete everything that was created.
    """
    # Reset the scoped-session registry, then bind session and metadata to the engine.
    db_session.registry.clear()
    db_session.configure(bind=engine)
    metadata.bind = engine
    metadata.create_all(engine)

    # Test basic operation
    # Constructing a Task already adds and flushes it (see Content.__init__), so task.id is usable.
    task = Task(title=u'Buy milk')
    task = get_content(task.id)

    # assert Content attributes inherited
    assert task.title == u'Buy milk'
    assert task.done == False

    # add subtasks
    task.tasks = [
        Task(title=u'Remember to check expiration date'),
        Task(title=u'Check bottle is not leaking')
    ]

    # assert that subtasks are added and correctly ordered
    task = get_content(task.id)
    assert len(task.tasks) == 2
    assert [(x.position, x.title) for x in task.tasks] == \
           [(0, u'Remember to check expiration date'),
            (1, u'Check bottle is not leaking')]

    # reorder subtasks: move the second subtask to the front
    task.tasks.insert(0, task.tasks.pop(1))
    task = get_content(task.id)
    assert len(task.tasks) == 2
    assert [(x.position, x.title) for x in task.tasks] == \
           [(0, u'Check bottle is not leaking'),
            (1, u'Remember to check expiration date')]

    # Test Project implementation
    project = Project(title=u'My project')
    milestone1 = Task(title=u'Milestone #1', description=u'First milestone')
    milestone2 = Task(title=u'Milestone #2', description=u'Second milestone')
    milestone1.tasks = [Task(title=u'Subtask for Milestone #1'), ]
    milestone2.tasks = [Task(title=u'Subtask #1 for Milestone #2'),
                        Task(title=u'Subtask #2 for Milestone #2')]
    project.milestones = [milestone1, milestone2]
    project = get_content(project.id)
    assert project.title == u'My project'
    assert len(project.milestones) == 2
    assert [(x.position, x.title) for x in project.milestones] == \
           [(0, u'Milestone #1'), (1, u'Milestone #2')]
    # 3 tasks from the first scenario + 2 milestones + 3 milestone subtasks
    assert len(Task.query.all()) == 8
    container = milestone1.container
    assert isinstance(container, Project) == True

    # Test Worklist implementation
    worklist = Worklist(title=u'My worklist', employee=u'Torkel Lyng')
    worklist.tasks = [
        Task(title=u'Ask stackoverflow for help'),
        Task(title=u'Learn SQLAlchemy')
    ]
    worklist = get_content(worklist.id)
    assert worklist.title == u'My worklist'
    assert worklist.employee == u'Torkel Lyng'
    assert len(worklist.tasks) == 2
    # the 8 tasks counted above + 2 worklist tasks
    assert len(Task.query.all()) == 10
    assert isinstance(worklist.tasks[0].container, Worklist) == True

    # Cleanup
    # Only the three top-level objects are deleted explicitly; the final
    # assert expects the relationship cascades to remove every remaining Task.
    task = Task.query.filter_by(title=u'Buy milk').one()
    task.delete()
    project.delete()
    worklist.delete()
    assert len(Task.query.all()) == 0


# Run the example only when this file is executed as a script.
if __name__ == '__main__':
    main()

Content 类上的 container 关系没有按预期工作:如果我没有显式指定 task.container = somecontainer,它就返回 None。因此我改用了一个属性方法(property),它返回 None 或容器对象。我会进一步研究这个问题,也许能找到更好的解决方案。仍然非常欢迎任何建议或替代方案。