SQLAlchemy:提前在关系列上动态应用过滤器

时间:2018-07-22 08:58:49

标签: python sqlalchemy django-rest-framework

我有两个表示多对一关系的SQLAlchemy类,例如:

class Person
    id = Column(Integer(), primary_key=True)
    name = Column(String(30))
    known_addresses = relationship('Address', backref='person')

class Address:
    id = Column(Integer(), primary_key=True)
    person_id = Column(Integer(), ForeignKey(Person.id, ondelete='cascade'))
    city = Column(String(30))
    zip_code = Column(String(10))

现在,说我有一个函数,该函数返回按邮政编码过滤的Person查询集(一个Select对象):

def get_persons_in_zip_code(zip_code):
    return session.query(Person).\
        join(Address).\
        where(Address.zip_code == zip_code)

一旦返回查询集,我将无法对其进行控制,并且预计它将封装我正在使用的框架(在我的情况下为Django / DRF)以呈现人员列表的所有数据以及他们的地址(因此代码会迭代查询集,为每个人调用.addresses并对其进行渲染)。

这很重要:我想确保调用.addresses只会返回 与原始zip_code过滤查询中匹配的地址-并非与该人相关的所有地址。

是否有一种方法可以在SQLAlchemy 中实现而无需访问稍后返回的Person对象?也就是说,我只能修改我的get_persons_in_zip_code函数或原始SQLAlchemy类,但是无法访问从查询返回的Person对象,因为这种情况发生在框架渲染代码的深处。

编辑:同样重要的是,在返回的查询对象上调用count()会产生预期的Person对象数,而不是Address对象数。

1 个答案:

答案 0 :(得分:0)

您似乎正在寻找contains_eager

编辑:一个更新的版本,可对.count()函数进行修补,以仅返回不同的Person计数。

from sqlalchemy import Integer, Column, String, ForeignKey
from sqlalchemy import create_engine, func, distinct
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker, contains_eager
from types import MethodType

engine = create_engine('sqlite:///:memory:', echo=True)
Session = sessionmaker(bind=engine)
session = Session()

Base = declarative_base()


class Person(Base):
    __tablename__ = "person"
    id = Column(Integer(), primary_key=True)
    name = Column(String(30))
    known_addresses = relationship('Address', backref='person')

    def __repr__(self):
        return "<Person {}>".format(self.name)


class Address(Base):
    __tablename__ = "address"
    id = Column(Integer(), primary_key=True)
    person_id = Column(Integer(), ForeignKey(Person.id, ondelete='cascade'))
    city = Column(String(30))
    zip_code = Column(String(10))

    def __repr__(self):
        return "<Address {}>".format(self.zip_code)


Base.metadata.create_all(engine)

p1 = Person(name="P1")
session.add(p1)
p2 = Person(name="P2")
session.add(p2)

session.commit()

a1 = Address(person_id=p1.id, zip_code="123")
session.add(a1)
a2 = Address(person_id=p1.id, zip_code="345")
session.add(a2)
a3 = Address(person_id=p2.id, zip_code="123")
session.add(a3)
a4 = Address(person_id=p1.id, zip_code="123")
session.add(a4)

session.commit()

def get_persons_in_zip_code(zip_code):
    return session.query(Person).\
        join(Person.known_addresses).\
        filter(Address.zip_code == zip_code).\
        options(contains_eager(Person.known_addresses))

def distinct_person_count(q):
    count_q = q.statement.with_only_columns([func.count(distinct(Person.id))])
    return q.session.execute(count_q).scalar()


results = get_persons_in_zip_code("123")
results.count = MethodType(distinct_person_count, results)

print(results.count())


for person in results:
    print(person)
    for address in person.known_addresses:
        print(address)

输出:

2
<Person P1>
<Address 123>
<Address 123>
<Person P2>
<Address 123>