使用SQLAlchemy ORM查询其他数据

时间:2015-02-12 23:12:34

标签: python sqlalchemy flask-sqlalchemy

假设我有一个简单的博客系统模型:

class Blog(db.Model):

    __tablename__ = 'blog'

    id = db.Column(db.BigInteger, primary_key=True)
    posts = db.relationship('Post', backref='blog', lazy='dynamic')

class Post(db.Model):

    __tablename__ = 'post'

    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String())

我正在使用Flask-SQLAlchemy,因此语法可能看起来有点不同,但映射非常简单。

是否可以在帖子计数的同时获取所有博客?因为现在迭代列表会在每次迭代时执行额外的查询,并且不会真正扩展。它在SQL中非常简单:

SELECT b.*, COUNT(p)
FROM blog b
    JOIN post p ON(b.id = p.blog_id)
GROUP BY b.id;

但是如何在SQLAlchemy中完成?我希望有两种选择:

  • 在查询时指定
  • 在模型中创建一个“虚拟”属性以便始终使用它(例如Hibernate调用此formulas

我认为我错过了正确的术语,因为我找不到谷歌的任何内容。

1 个答案:

答案 0 :(得分:0)

见下文,代码应该回答你的两个问题。对第二部分的回答使用Hybrid Extension

class Blog(db.Model):
    __tablename__ = 'blog'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String)
    posts = db.relationship('Post', backref='blog', lazy='dynamic')

    @hybrid_property
    def num_posts(self):
        # return len(self.posts)  # use if relationship is *not* dynamic
        return self.posts.count()

    @num_posts.expression
    def _num_posts_expression(cls):
        return (db.select([db.func.count(Post.id)])
                .where(Post.post_id == cls.id)
                .label("num_posts")
                )


class Post(db.Model):
    __tablename__ = 'post'

    id = db.Column(db.Integer, primary_key=True)
    post_id = db.Column(db.ForeignKey(Blog.id))
    title = db.Column(db.String())


def test():
    with app.app_context():
        db.drop_all()
        db.create_all()

        def create_test_data():
            blogs = [
                Blog(name="empty",),
                Blog(name="geo",
                     posts=[
                         Post(title='west'),
                         Post(title='east'),
                         Post(title='north'),
                         Post(title='south'),
                     ]),
                Blog(name="food",
                     posts=[
                         Post(title='sour'),
                         Post(title='sweet'),
                     ]),
            ]
            db.session.add_all(blogs)
            db.session.commit()

        create_test_data()


        # simple query
        q = (db.session.query(Blog, db.func.count(Post.id).label("num_posts"))
             .outerjoin(Post, Blog.posts)
             .group_by(Blog.id)
             )
        for b, num_posts in q:
            print(b, num_posts)
        print("-"*80)


        # Using hybrid_attribute
        # num_posts is fetched later
        q = (db.session.query(Blog))
        for b in q:
            print(b)
            print(b.num_posts)
        print("-"*80)

        # num_posts is fetched with the query
        q = (db.session.query(Blog, Blog.num_posts))
        for b in q:
            print(b)
            print(b.num_posts)
        print("-"*80)