SQLAlchemy 查询常用属性?

时间:2021-01-09 15:23:42

标签: python sqlalchemy

我有一个电影和属于这些电影的类别的关联表,我想获取两部电影之间的所有常见类别(只需要类别的 id)。

因此,如果两部电影的类别 'Thriller' 的 category_id 为 5,我想得到 5。如果它们没有共同的类别,则只返回 None。

表格看起来像:

class MovieCategoryScores(db.Model):
    movie_id = db.Column(db.Integer, db.ForeignKey('movie.id'), primary_key=True)
    category_id = db.Column(db.Integer, db.ForeignKey('category.id'), primary_key=True)
    score = db.Column(db.Integer)
    votes = db.Column(db.Integer)
    category = relationship("Category", back_populates="movies")
    movie = relationship("Movie", back_populates="categories")

我知道我可以查询 categories = MovieCategoryScores.query.filter(MovieCategoryScores.movie_id.in_([movie1, movie2])).all() 获取所有类别,我尝试在查询后放置 (MovieCategoryScores.category_id) 以仅获取 ID,但这不起作用,我刚刚收到 TypeError: 'BaseQuery' object is not callable 错误。

如果我想出了如何只是获取 ID,我可以使用以下方法:

categories.sort()
for index, category_id in enumerate(categories.copy()):
    if categories[index+1] != category_id:
        categories[index].remove()
return categories

要获取仅包含 2 个 ID 的列表,但感觉应该有一些更好的方法可以仅通过查询命令获取具有相同 category_id 的项目的 ID?

任何一种解决方案都将不胜感激!

2 个答案:

答案 0 :(得分:1)

您可以使用 havingfunc.count() > 1 来获得 distinct 的反义词(group_by 需要 having)。

from sqlalchemy import func
categories = MovieCategoryScores.query.with_entities(MovieCategoryScores.category_id).filter(MovieCategoryScores.movie_id.in_([movie1, movie2])).group_by(MovieCategoryScores.category_id).having(func.count(MovieCategoryScores.category_id) > 1).all()

或者,如果您想检索 Category.name,您可以执行以下操作:

from sqlalchemy import func
categories = MovieCategoryScores.query.with_entities(MovieCategoryScores.name).filter(MovieCategoryScores.movie_id.in_([movie1, movie2])).group_by(MovieCategoryScores.category).having(func.count(MovieCategoryScores.category) > 1).all()

答案 1 :(得分:0)

好吧,在经历了很多头痛之后,我想出了这个:

def get_common_categories(movie1, movie2):
    categories = MovieCategoryScores.query.with_entities(MovieCategoryScores.category_id).filter(MovieCategoryScores.movie_id.in_([movie1, movie2])).all()
    categories.sort()
    common = []
    for index, category in enumerate(categories):
        if index+1 < len(categories) and categories[index+1] == category:
            common.append(category[0])
    return common

感觉很愚蠢,就像应该有某种方法可以只使用查询和过滤器来完成,但无法弄清楚,所以现在必须这样做。

相关问题