如何在一对多关系中,根据“多”的一方的条件来选择“一”的一方

时间:2015-07-24 09:52:35

标签: python sqlalchemy

我使用节点表来表示树,如何选出拥有4个或更多子节点的父节点?

这是测试代码:

from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import Session, relationship, backref, joinedload_all
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.collections import attribute_mapped_collection


# Declarative base class; the mapped Node model in this script inherits from it.
Base = declarative_base()

class Node(Base):
    """A tree node stored as an adjacency list in the ``nodes`` table.

    Every row refers to its parent row through ``parent_id``.  ``children``
    and ``parent`` are the two sides of that single self-referential
    foreign key.
    """

    __tablename__ = 'nodes'
    id = Column(Integer, primary_key=True)
    # Self-referential foreign key pointing at the parent row's id.
    parent_id = Column(Integer, ForeignKey("nodes.id"))
    name = Column(String, nullable=False)

    # Both relationships are driven by the same parent_id column, so they are
    # linked with back_populates: assigning to one side updates the other in
    # memory, and SQLAlchemy knows they describe one association instead of
    # two conflicting ones.
    children = relationship(
        "Node",
        back_populates="parent",
        cascade="all, delete-orphan",
    )
    # remote_side=[id] marks this as the many-to-one direction (child -> parent).
    parent = relationship(
        "Node",
        back_populates="children",
        remote_side=[id],
    )

    def __init__(self, name, parent=None):
        """Create a node called *name*, optionally attached to *parent*."""
        self.name = name
        self.parent = parent

    def __repr__(self):
        """Return a debug representation showing name, id and parent_id."""
        return "Node(name=%r, id=%r, parent_id=%r)" % (
            self.name,
            self.id,
            self.parent_id,
        )

# SQLite engine with no database path (in-memory), SQL echo disabled; the
# schema is created from the mapped classes before opening a session.
engine = create_engine('sqlite://', echo=False)
Base.metadata.create_all(engine)
session = Session(engine)

import string

# One Node per uppercase ASCII letter, keyed by its name.
nodes = {letter: Node(letter) for letter in string.ascii_uppercase}

# Build the sample tree: A gets 3 children, B gets 6, C gets 7.
for parent_name, child_names in (
    ("A", "XYZ"),
    ("B", "HIJKLM"),
    ("C", "NPQRSTU"),
):
    nodes[parent_name].children = [nodes[child] for child in child_names]

# Persist every node (parents and children alike) in one commit.
session.add_all(nodes.values())
session.commit()

然后我可以通过以下代码获得结果:

from sqlalchemy.orm import aliased
from sqlalchemy import func
# Two aliases of the same mapped class: n1 is matched against n2's parent_id,
# so n1 plays the parent role and n2 the child role.
n1 = aliased(Node)
n2 = aliased(Node)

# Subquery yielding (id, count) per parent: rows are grouped by the child's
# parent_id and counted.
q = (
    session.query(
        n1.id.label("id"),
        func.count(n1.id).label("count"),
    )
    .filter(n1.id == n2.parent_id)
    .group_by(n2.parent_id)
    .subquery()
)

# Match the counts back to Node rows and keep those with a count above 4.
(
    session.query(Node, q.c.count)
    .filter(q.c.id == Node.id, q.c.count > 4)
    .all()
)

如何简化这段查询?比如写成类似下面这样:

# Illustrative only: the simplified form the question is asking for
# (NOTE(review): presented as a wish, not as a working query).
session.query(Node).filter(func.count(Node.children) > 4).all()

1 个答案:

答案 0 :(得分:1)

如果这是您经常需要的,请考虑使用hybrid_property。在您的代码中,它可能如下所示:

# Names used by the hybrid below that the question's imports do not provide.
from sqlalchemy import select
from sqlalchemy.ext.hybrid import hybrid_property


class Node(Base):
    # ...

    @hybrid_property
    def children_count(self):
        """Number of children of this instance, counted in Python."""
        return len(self.children)

    @children_count.expression
    def children_count(cls):
        """SQL form of ``children_count`` for use in queries.

        Returns a labeled subquery that counts the rows whose ``parent_id``
        equals the id of the node being selected.
        """
        # A separate alias is needed so the counted child rows are
        # distinguished from the node rows of the enclosing query.
        Child = aliased(Node)
        return (
            select([func.count(Child.id)])
            .where(Child.parent_id == cls.id)
            .label('children_count')
        )

这样一来,查询就像下面这样简单:
# Select each node together with its children_count, keeping only the nodes
# whose count is above 4; the hybrid supplies both the column and the filter.
q = (
    session.query(Node, Node.children_count)
    .filter(Node.children_count > 4)
    .all()
)
print(q)