SQLAlchemy 多对多关系中的交集查询(查找有共同标签的对象对)

时间:2011-07-17 01:07:10

标签: python sqlalchemy

我有一个带有 .tags 属性的 Character 类;.tags 属性是 Tag 对象的列表,二者之间是多对多关系。我想编写一个查询,找出所有名称不同、且至少有一个共同标签的角色对;该怎么做?

1 个答案:

答案 0 :(得分:2)

你可以这样做:

  1. 先想出一条能得到所需结果的 SQL 查询;
  2. 再写出对应的 SQLAlchemy(SA)查询。

对于下面的测试数据,SQL 查询(在 SQL Server 上使用 WITH 子句)如下所示(显然你的表名和列名可能与此不同):

    -- Return every pair of characters (lower id first) that have different
    -- names and share at least one tag.
    -- String literals use single quotes: with SET QUOTED_IDENTIFIER ON (SQL
    -- Server's usual setting) a double-quoted "ch-1" is parsed as a column
    -- name, not a string.
    WITH t_character (id, name)
    AS (    SELECT  1, 'ch-1'
    UNION   SELECT  2, 'ch-2'
    UNION   SELECT  3, 'ch-3'
    UNION   SELECT  4, 'ch-4'
    )
    , t_tag (id, name)
    AS (    SELECT  1, 'tag-1'
    UNION   SELECT  2, 'tag-2'
    UNION   SELECT  3, 'tag-3'
    )
    , t_character_tag (character_id, tag_id)
    AS (    SELECT  1, 1
    UNION   SELECT  2, 1
    UNION   SELECT  2, 2
    UNION   SELECT  3, 1
    UNION   SELECT  3, 2
    UNION   SELECT  3, 3
    UNION   SELECT  4, 3
    )
    -- Before DISTINCT the joins yield (1, 2), (1, 3), (2, 3) twice (that pair
    -- shares tags 1 and 2), and (3, 4); DISTINCT leaves 4 pairs.
    SELECT      DISTINCT -- collapse pairs that share more than one tag
                c1.id, c2.id
    FROM        t_character c1
    INNER JOIN  t_character c2
            ON  c1.id < c2.id -- each unordered pair once, never a character with itself
    INNER JOIN  t_character_tag r1
            ON  r1.character_id = c1.id
    INNER JOIN  t_character_tag r2
            ON  r2.character_id = c2.id
    WHERE       r1.tag_id = r2.tag_id -- at least one common tag
            AND c1.name <> c2.name    -- the question asks for pairs with different names
    ORDER BY    c1.id, c2.id
    

    包含您需要的查询的完整示例代码如下:

    from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Table
    from sqlalchemy.orm import relationship, scoped_session, sessionmaker, aliased
    from sqlalchemy.ext.declarative import declarative_base
    
    # Test database for SA: an in-memory SQLite engine (no SQL echo) and a
    # scoped_session around a sessionmaker bound to it, with autoflush off.
    engine = create_engine("sqlite:///:memory:", echo=False)
    session = scoped_session(
        sessionmaker(autoflush=False, bind=engine)
    )
    
    class Base(object):
        """Helper base class for the declarative models.

        ``__init__`` sets every keyword argument as an attribute, so a model can
        be created as ``Character(id=1, name="ch-1")``.

        ``__repr__`` renders ``<ClassName(attr=value, ...)>`` from the instance
        ``__dict__``, sorted by attribute name. It skips keys starting with
        ``_sa_`` or ``_backref_``, which are internal bookkeeping entries.

        Because ``__repr__`` reads ``__dict__`` directly, it only shows state
        that is already loaded and never triggers a lazy load. Relationships
        that are already loaded are included, so the output can get large and
        nested.
        """

        def __init__(self, **kwargs):
            for name, value in kwargs.items():
                setattr(self, name, value)

        def __repr__(self):
            state = self.__dict__
            shown = ", ".join(
                "%s=%r" % (key, state[key])
                for key in sorted(state)
                # str.startswith accepts a tuple of prefixes (Python 2.5+).
                if not key.startswith(("_sa_", "_backref_"))
            )
            return "<%s(%s)>" % (self.__class__.__name__, shown)
    # Rebind Base to a declarative base that mixes in the helper class above.
    Base = declarative_base(cls=Base)

    # Association table for the many-to-many Character <-> Tag relationship.
    # The two columns together form the primary key, so the same
    # character/tag link cannot be stored twice.
    t_character_tag = Table(
        "t_character_tag", Base.metadata,
        Column("character_id", Integer, ForeignKey("t_character.id"), primary_key=True),
        Column("tag_id", Integer, ForeignKey("t_tag.id"), primary_key=True),
        )
    
    class Character(Base):
        """A named character with any number of tags.

        ``tags`` is loaded through the ``t_character_tag`` association table.
        The backref also gives each ``Tag`` a ``characters`` collection.
        """
        __tablename__ = u"t_character"

        id = Column(Integer, primary_key=True)
        name = Column(String())
        tags = relationship(
            "Tag",
            backref="characters",
            secondary=t_character_tag,
        )
    
    class Tag(Base):
        """A named tag.

        ``Tag.characters`` is not declared here; it is created by the backref
        on ``Character.tags``.
        """
        __tablename__ = u"t_tag"

        id = Column(Integer, primary_key=True)
        name = Column(String())
    
    # Create all tables in the in-memory database.
    Base.metadata.create_all(engine)


    # 0. Test data. Tag assignments:
    #    ch-1: tag-1
    #    ch-2: tag-1, tag-2
    #    ch-3: tag-1, tag-2, tag-3
    #    ch-4: tag-3
    ch1 = Character(id=1, name="ch-1")
    ch2 = Character(id=2, name="ch-2")
    ch3 = Character(id=3, name="ch-3")
    ch4 = Character(id=4, name="ch-4")
    ta1 = Tag(id=1, name="tag-1")
    ta2 = Tag(id=2, name="tag-2")
    ta3 = Tag(id=3, name="tag-3")
    ch1.tags.append(ta1)
    ch2.tags.extend([ta1, ta2])
    ch3.tags.extend([ta1, ta2, ta3])
    ch4.tags.append(ta3)
    session.add_all([ch1, ch2, ch3, ch4])
    session.commit()
    
    # 1. some data checks
    # expunge_all() detaches every object from the session, so the queries
    # below load fresh instances instead of reusing the ones created above.
    session.expunge_all()
    assert len(session.query(Character).all()) == 4
    assert session.query(Tag).get(2).name == "tag-2"  # primary-key lookup
    assert len(session.query(Character).get(3).tags) == 3  # ch-3 was given all three tags
    # NOTE(review): Query.get() is legacy as of SQLAlchemy 1.4 (Session.get()
    # replaces it); this script appears to target an older SQLAlchemy release.
    
    # 2. create a final query (THE ANSWER TO THE QUESTION): every pair of
    #    characters that have different names and share at least one tag.
    session.expunge_all()
    t_c1 = aliased(Character)
    t_c2 = aliased(Character)
    t_t1 = aliased(Tag)
    t_t2 = aliased(Tag)
    q = (session.query(t_c1, t_c2)
         # c1.id < c2.id yields each unordered pair once and never pairs a
         # character with itself.
         .join(t_c2, t_c1.id < t_c2.id)
         .join(t_t1, t_c1.tags)
         .join(t_t2, t_c2.tags)
         .filter(t_t1.id == t_t2.id)  # at least one common tag
         # Different character names, as the question asks; this filter is
         # redundant only if character names are unique.
         .filter(t_c1.name != t_c2.name)
         .order_by(t_c1.id, t_c2.id)
         )
    # A pair that shares several tags is produced once per shared tag;
    # DISTINCT collapses those duplicates.
    q = q.distinct()
    res = q.all()
    assert len(res) == 4
    for _r in res:
        print(_r)  # with one argument this works on both Python 2 and 3