假设我有一个简单的博客系统模型:
class Blog(db.Model):
__tablename__ = 'blog'
id = db.Column(db.BigInteger, primary_key=True)
posts = db.relationship('Post', backref='blog', lazy='dynamic')
class Post(db.Model):
__tablename__ = 'post'
id = db.Column(db.Integer, primary_key=True)
title = db.Column(db.String())
我正在使用Flask-SQLAlchemy,因此语法可能看起来有点不同,但映射非常简单。
是否可以在帖子计数的同时获取所有博客?因为现在迭代列表会在每次迭代时执行额外的查询,并且不会真正扩展。它在SQL中非常简单:
SELECT b.*, COUNT(p)
FROM blog b
JOIN post p ON(b.id = p.blog_id)
GROUP BY b.id;
但是如何在SQLAlchemy中完成?我希望有两种选择:
我认为我错过了正确的术语,因为我找不到谷歌的任何内容。
答案 0 :(得分:0)
见下文,代码应该回答你的两个问题。对第二部分的回答使用Hybrid Extension。
class Blog(db.Model):
__tablename__ = 'blog'
id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.String)
posts = db.relationship('Post', backref='blog', lazy='dynamic')
@hybrid_property
def num_posts(self):
# return len(self.posts) # use if relationship is *not* dynamic
return self.posts.count()
@num_posts.expression
def _num_posts_expression(cls):
return (db.select([db.func.count(Post.id)])
.where(Post.post_id == cls.id)
.label("num_posts")
)
class Post(db.Model):
__tablename__ = 'post'
id = db.Column(db.Integer, primary_key=True)
post_id = db.Column(db.ForeignKey(Blog.id))
title = db.Column(db.String())
def test():
with app.app_context():
db.drop_all()
db.create_all()
def create_test_data():
blogs = [
Blog(name="empty",),
Blog(name="geo",
posts=[
Post(title='west'),
Post(title='east'),
Post(title='north'),
Post(title='south'),
]),
Blog(name="food",
posts=[
Post(title='sour'),
Post(title='sweet'),
]),
]
db.session.add_all(blogs)
db.session.commit()
create_test_data()
# simple query
q = (db.session.query(Blog, db.func.count(Post.id).label("num_posts"))
.outerjoin(Post, Blog.posts)
.group_by(Blog.id)
)
for b, num_posts in q:
print(b, num_posts)
print("-"*80)
# Using hybrid_attribute
# num_posts is fetched later
q = (db.session.query(Blog))
for b in q:
print(b)
print(b.num_posts)
print("-"*80)
# num_posts is fetched with the query
q = (db.session.query(Blog, Blog.num_posts))
for b in q:
print(b)
print(b.num_posts)
print("-"*80)