在 SQLAlchemy 中连接(join)两张彼此没有直接关联的表

时间:2017-11-01 00:40:24

标签: python sql sqlalchemy

这是我的问题。我有三张表。

一张名为 Project 的表,只有一个名为 id 的列(该 id 在整个系统中必须唯一);一张名为 ServiceAwarenessProject 的表,与 Project.id 是一对一的关系;一张名为 CorporateVPNProject 的表,与 Project.id 也是一对一的关系。

我正在使用sqlalchemy ORM,所以代码如下所示:

class Project(SqlAlchemyBase):
    """Table of project ids; ``sa_project`` and ``wvpn_project`` reference ``project.id``."""
    __tablename__ = 'project'

    # The only column: an integer primary key declared with autoincrement.
    id = Column(Integer, primary_key=True, autoincrement=True)


class ServiceAwarenessProject(SqlAlchemyBase):
    """Service Awareness project row, linked to ``Project`` through ``project_id``."""
    __tablename__ = 'sa_project'

    # Surrogate primary key local to this table.
    id = Column(Integer, primary_key=True)
    # Foreign key to project.id.
    project_id = Column(Integer, ForeignKey(Project.id))
    mop_url = Column(String, nullable=False)
    # SQLAlchemy spells this type ``DateTime``; ``Datetime`` is not a defined
    # name and would raise NameError when the class body is executed.
    expiration_date = Column(DateTime, index=True)


class CorporateVPNProject(SqlAlchemyBase):
    """Corporate VPN project row, linked to ``Project`` through ``project_id``."""
    __tablename__ = 'wvpn_project'

    # Surrogate primary key local to this table.
    id = Column(Integer, primary_key=True)
    # Foreign key to project.id.
    project_id = Column(Integer, ForeignKey(Project.id))
    mop_url = Column(String, nullable=False)

我这样设计表结构,是为了保证 project_id 在整个系统中是唯一的。我的问题是:我不知道如何把这些表连接起来,以便根据 project_id 找到对应的项目。作为临时的解决办法,我在名为 get_project_by_id 的函数里依次查询这两张表。有没有更聪明的方法来解决这个问题?

class ProjectService:
    """Creates projects and looks them up across the two per-type tables."""

    @staticmethod
    def create_project_id():
        """Insert a new ``Project`` row and return its id as a string.

        The id is ``YYYYMMDD`` followed by a zero-padded counter. When the
        highest existing ``Project.id`` starts with today's date, its counter
        is incremented; otherwise the counter starts at ``001``.
        """
        db = DbSessionFactory.create_session()
        latest = db.query(Project.id).order_by(desc(Project.id)).first()

        project_id = None
        if latest:
            last_id = str(latest[0])
            day_part, counter_part = last_id[:8], last_id[8:]
            if day_part == datetime.datetime.now().strftime('%Y%m%d'):
                project_id = day_part + '{:03d}'.format(int(counter_part) + 1)

        # No usable previous id for today: start a fresh sequence.
        if project_id is None:
            project_id = datetime.datetime.now().strftime('%Y%m%d') + '001'

        db.add(Project(id=project_id))
        db.commit()
        return project_id

    @staticmethod
    def get_project_by_id(project_id):
        """Return the first project row whose ``project_id`` matches, or ``None``.

        ``ServiceAwarenessProject`` is searched first, then
        ``CorporateVPNProject``.
        """
        db = DbSessionFactory.create_session()
        for model in (ServiceAwarenessProject, CorporateVPNProject):
            match = db.query(model).filter(model.project_id == project_id).first()
            if match:
                return match
        return None

    def create_serviceawareness_project(self):
        """Allocate a project id, then insert and return a ``ServiceAwarenessProject``."""
        db = DbSessionFactory.create_session()
        allocated_id = self.create_project_id()
        sa_project = ServiceAwarenessProject(
            project_id=allocated_id,
            mop_url='http://www.thepacketwizards.com/1',
        )
        db.add(sa_project)
        db.commit()
        return sa_project

    def create_corporatevpn_project(self):
        """Allocate a project id, then insert and return a ``CorporateVPNProject``."""
        db = DbSessionFactory.create_session()
        allocated_id = self.create_project_id()
        vpn_project = CorporateVPNProject(
            project_id=allocated_id,
            mop_url='http://www.thepacketwizards.com/wvpn',
        )
        db.add(vpn_project)
        db.commit()
        return vpn_project

谢谢!

1 个答案:

答案 0 :(得分:1)

按照 @IljaEveilä 的建议,我改用 joined table inheritance(连接表继承)来设计这些表,如下所示。

class Project(SqlAlchemyBase):
    """Base class of the joined-table inheritance hierarchy, stored in ``project``."""
    __tablename__ = 'project'

    id = Column(Integer, primary_key=True)
    # Set to the current local time on insert.
    created_on = Column(DateTime, default=datetime.datetime.now)
    # Set on insert and refreshed on every update.
    updated_on = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    project_url = Column(String(60))
    mop_url = Column(String(60))
    input_url = Column(String(60))
    # Discriminator column: holds the polymorphic identity of each row.
    type = Column(String(60))
    __mapper_args__ = {
        'polymorphic_identity': 'project',
        'polymorphic_on': type
    }


class ServiceAwarenessProject(Project):
    """``Project`` subclass stored in ``sa_project`` (identity ``'ServiceAwareness'``)."""
    __tablename__ = 'sa_project'

    # Primary key that is also the foreign key to the base row in project.
    id = Column(Integer, ForeignKey('project.id'), primary_key=True)
    expiration_date = Column(DateTime)
    __mapper_args__ = {
        'polymorphic_identity': 'ServiceAwareness',
    }


class CorporateVPNProject(Project):
    """``Project`` subclass stored in ``wvpn_project`` (identity ``'CorporateVPN'``)."""
    __tablename__ = 'wvpn_project'

    # Primary key that is also the foreign key to the base row in project.
    id = Column(Integer, ForeignKey('project.id'), primary_key=True)
    client_name = Column(String(60))
    __mapper_args__ = {
        'polymorphic_identity': 'CorporateVPN',
    }

现在,查询数据库时需要使用 with_polymorphic,这样返回的每一行都会是其对应子类(ServiceAwarenessProject 或 CorporateVPNProject)的实例。

class ProjectService:
    """Creates and fetches projects mapped with joined table inheritance."""

    @staticmethod
    def create_project_id():
        """Return the next project id as a string, without inserting anything.

        The id is today's date as ``YYYYMMDD`` followed by a zero-padded
        counter. When the highest existing ``Project.id`` starts with today's
        date, its counter is incremented; otherwise the counter starts at
        ``001``.
        """
        session = DbSessionFactory.create_session()
        latest = session.query(Project.id).order_by(desc(Project.id)).first()

        # Read the clock once so the comparison and the generated id always
        # refer to the same day.
        today = datetime.datetime.now().strftime('%Y%m%d')

        if latest:
            last_id = str(latest[0])
            if last_id[:8] == today:
                return today + '{:03d}'.format(int(last_id[8:]) + 1)

        return today + '001'

    @staticmethod
    def _save(project):
        """Add *project* to a new session, commit it, and return it."""
        session = DbSessionFactory.create_session()
        session.add(project)
        session.commit()
        return project

    def create_serviceawareness_project(self):
        """Insert and return a ``ServiceAwarenessProject`` expiring 365 days from now."""
        new_project = ServiceAwarenessProject(
            id=self.create_project_id(),
            project_url='http://project',
            expiration_date=datetime.datetime.now() + datetime.timedelta(days=365),
            mop_url='http://mop',
            input_url='http://url',
            type='ServiceAwareness',
        )
        return self._save(new_project)

    def create_corporatevpn_project(self):
        """Insert and return a ``CorporateVPNProject``."""
        new_project = CorporateVPNProject(
            id=self.create_project_id(),
            project_url='http://project',
            client_name='TIM',
            mop_url='http://mop',
            input_url='http://url',
            type='CorporateVPN',
        )
        return self._save(new_project)

    @staticmethod
    def get_project_by_id(project_id):
        """Return the subclass instance whose id equals *project_id*, or ``None``.

        The query selects from a ``with_polymorphic`` entity so each result is
        loaded as its concrete subclass.
        """
        session = DbSessionFactory.create_session()
        project_poly = with_polymorphic(
            Project, [ServiceAwarenessProject, CorporateVPNProject]
        )
        # Filter through the polymorphic entity's sub-entity attributes so the
        # criteria refer to the same selectable the query selects from.
        return session.query(project_poly).filter(or_(
            project_poly.ServiceAwarenessProject.id == project_id,
            project_poly.CorporateVPNProject.id == project_id,
        )).first()