SQLAlchemy, array_agg, and matching an input list

Posted: 2018-03-22 21:10:10

Tags: python postgresql sqlalchemy

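I'm trying to use SQLAlchemy more fully, rather than falling back to raw SQL at the first sign of difficulty. In this case, I have a table in a Postgres (9.5) database that stores a set of integers as a group by associating individual items, atom_id, with a group identifier, group_id.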

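Given a list of atom_ids, I'd like to be able to find out which group_id that list of atom_ids belongs to. Solving this with only the group_id and atom_id columns was straightforward.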

Now I'm trying to generalize this so that a "group" is made up of not just a list of atom_ids but other context as well. In the example below, the list is ordered by including a sequence column, but conceptually other columns could be used instead, such as a weight column that gives each atom_id a [0,1] floating-point value representing that atom's "share" of the group.

Below is a unit test that mostly demonstrates my problem.

First, some setup:

    def test_multi_column_grouping(self):
        class MultiColumnGroups(base.Base):
            __tablename__ = 'multi_groups'

            group_id = Column(Integer)
            atom_id = Column(Integer)
            sequence = Column(Integer)  # arbitrary 'other' column. In this case, an integer, but it could be a float (e.g. weighting factor)

        base.Base.metadata.create_all(self.engine)

        # Insert 6 rows representing 2 different 'groups' of values
        vals = [
            # Group 1
            {'group_id': 1, 'atom_id': 1, 'sequence': 1},
            {'group_id': 1, 'atom_id': 2, 'sequence': 2},
            {'group_id': 1, 'atom_id': 3, 'sequence': 3},
            # Group 2
            {'group_id': 2, 'atom_id': 1, 'sequence': 3},
            {'group_id': 2, 'atom_id': 2, 'sequence': 2},
            {'group_id': 2, 'atom_id': 3, 'sequence': 1},
        ]

        self.session.bulk_save_objects(
            [MultiColumnGroups(**x) for x in vals])
        self.session.flush()

        self.assertEqual(6, len(self.session.query(MultiColumnGroups).all()))

    from collections import namedtuple
    Entity = namedtuple('Entity', ['atom_id', 'sequence'])
    values_to_match = [
        # (atom_id, sequence)
        Entity(1, 3),
        Entity(2, 2),
        Entity(3, 1),
        ]
    # The above list _should_ match with `group_id == 2`

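Now I want to query the table above to find which group a specific set of inputs belongs to. I'm using a list of (named) tuples, as above, to represent the query parameters.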

    r = self.session.execute('''
        select group_id
        from multi_groups
        group by group_id
        having array_agg((atom_id, sequence)) = :query_tuples
        ''', {'query_tuples': values_to_match}).fetchone()
    print(r)  # > (2,)
    self.assertEqual(2, r[0])

That is the raw SQL solution. I'd prefer not to rely on it, since part of this exercise is learning more SQLAlchemy.

Below is a direct translation of the raw SQL solution above into a SQLAlchemy query, and it breaks:

    from sqlalchemy import tuple_
    from sqlalchemy.dialects.postgresql import array_agg

    existing_group = self.session.query(MultiColumnGroups).\
        with_entities(MultiColumnGroups.group_id).\
        group_by(MultiColumnGroups.group_id).\
        having(array_agg(tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.sequence)) == values_to_match).\
        one_or_none()

    self.assertIsNotNone(existing_group)
    print('|{}|'.format(existing_group))

Running it produces a psycopg2 error:

(psycopg2.ProgrammingError) operator does not exist: record[] = integer[]

I believe I need to cast the array_agg to int[]? That would work as long as the grouping columns are all integers (an acceptable limitation if needed), but ideally this would work with mixed-type input tuples/table columns.

Is the session.query() above close? Am I blinding myself here and missing something really obvious that would solve this problem another way?

2 answers:

Answer 0 (score: 2)

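I think your solution produces nondeterministic results, because the rows within a group come back in unspecified order, so comparing the array aggregate with the given array may yield true or false depending on that order. In the session below, a no-op UPDATE rewrites the rows with atom_id = 2, which is enough to change the order in which they are read back: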

[local]:5432 u@sopython*=> select group_id
[local] u@sopython- > from multi_groups 
[local] u@sopython- > group by group_id
[local] u@sopython- > having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
 group_id 
----------
        2
(1 row)

[local]:5432 u@sopython*=> update multi_groups set atom_id = atom_id where atom_id = 2;
UPDATE 2
[local]:5432 u@sopython*=> select group_id                                             
from multi_groups 
group by group_id
having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
 group_id 
----------
(0 rows)

You could apply an ordering to both sides, or try a completely different approach: instead of comparing arrays, you could use relational division.
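The "order both sides" fix can be illustrated without a database. This is a plain-Python model, not SQLAlchemy code: the hypothetical helper `find_group_ordered` groups rows by `group_id` and compares each group's `(atom_id, sequence)` pairs with the input after sorting both, which is what `array_agg(... ORDER BY atom_id, sequence)` (exposed in SQLAlchemy's PostgreSQL dialect as `aggregate_order_by`) compared against an input list sorted the same way would achieve. For text columns, the database collation may sort differently from Python.

```python
import random
from collections import defaultdict, namedtuple

Entity = namedtuple('Entity', ['atom_id', 'sequence'])

# Rows as (group_id, atom_id, sequence), mirroring the multi_groups test data.
ROWS = [
    (1, 1, 1), (1, 2, 2), (1, 3, 3),
    (2, 1, 3), (2, 2, 2), (2, 3, 1),
]


def find_group_ordered(rows, values_to_match):
    """Return the group_ids whose (atom_id, sequence) pairs equal the input.

    Both sides are sorted first, so the result no longer depends on the
    order in which the rows happen to arrive (hypothetical helper).
    """
    groups = defaultdict(list)
    for group_id, atom_id, sequence in rows:
        groups[group_id].append((atom_id, sequence))
    target = sorted(tuple(v) for v in values_to_match)
    return sorted(gid for gid, pairs in groups.items() if sorted(pairs) == target)


values_to_match = [Entity(1, 3), Entity(2, 2), Entity(3, 1)]

# Shuffling the rows plays the role of PostgreSQL returning them in another order.
shuffled = ROWS[:]
random.shuffle(shuffled)
print(find_group_ordered(shuffled, values_to_match))  # [2]
```

Without the two `sorted()` calls, the same comparison would succeed or fail depending on the shuffle, which is the behaviour shown in the psql session above.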

For the division you have to form a temporary relation from your list of Entity records. Again, there are many ways to do that. Here's one using unnested arrays:

In [112]: vtm = select([
     ...:     func.unnest(postgresql.array([
     ...:         getattr(e, f) for e in values_to_match
     ...:     ])).label(f)
     ...:     for f in Entity._fields
     ...: ]).alias()
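The `unnest` version turns the list of records inside out: it builds one array per field, and because all the arrays have the same length PostgreSQL pairs their elements back up position by position. The transposition done by the list comprehension can be seen in plain Python; `columns_for_unnest` is a hypothetical helper, and the names otherwise follow the question.

```python
from collections import namedtuple

Entity = namedtuple('Entity', ['atom_id', 'sequence'])
values_to_match = [Entity(1, 3), Entity(2, 2), Entity(3, 1)]


def columns_for_unnest(entities, fields):
    """One list per field: what each postgresql.array(...) in the answer receives."""
    return {f: [getattr(e, f) for e in entities] for f in fields}


columns = columns_for_unnest(values_to_match, Entity._fields)
print(columns)  # {'atom_id': [1, 2, 3], 'sequence': [3, 2, 1]}

# Zipping the per-field lists back together mirrors what the parallel unnest()
# calls do row by row, so the original records come back out.
rows = [Entity(*row) for row in zip(*(columns[f] for f in Entity._fields))]
assert rows == values_to_match
```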

Another one using a union:

In [114]: vtm = union_all(*[
     ...:     select([literal(e.atom_id).label('atom_id'),
     ...:             literal(e.sequence).label('sequence')])
     ...:     for e in values_to_match
     ...: ]).alias()

A temporary table would work too.

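With the new relation at hand, you want to find the answer to "find those multi_groups for which no entity exists that does not exist in the group". That's a horrible sentence, but it makes sense: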

In [117]: mg = aliased(MultiColumnGroups)

In [119]: session.query(MultiColumnGroups.group_id).\
     ...:     filter(~exists().
     ...:         select_from(vtm).
     ...:         where(~exists().
     ...:             where(MultiColumnGroups.group_id == mg.group_id).
     ...:             where(tuple_(vtm.c.atom_id, vtm.c.sequence) ==
     ...:                   tuple_(mg.atom_id, mg.sequence)).
     ...:             correlate_except(mg))).\
     ...:     distinct().\
     ...:     all()
     ...: 
Out[119]: [(2,)]
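The double `NOT EXISTS` reads as "no input entity exists that this group lacks". To try that logic without PostgreSQL, here is a sketch of the same relational division written as plain SQL and run on SQLite through Python's standard `sqlite3` module. The `VALUES` CTE stands in for the unnest/union relation, and `groups_containing` is a hypothetical helper. Like the query above, it returns every group that contains all of the inputs, so a group with extra rows would match as well.

```python
import sqlite3
from collections import namedtuple

Entity = namedtuple('Entity', ['atom_id', 'sequence'])


def groups_containing(conn, entities):
    """Relational division: groups for which no input entity is missing."""
    placeholders = ', '.join('(?, ?)' for _ in entities)
    sql = f'''
        WITH vtm(atom_id, sequence) AS (VALUES {placeholders})
        SELECT DISTINCT g.group_id
        FROM multi_groups AS g
        WHERE NOT EXISTS (          -- there is no input entity ...
            SELECT 1 FROM vtm
            WHERE NOT EXISTS (      -- ... that this group lacks
                SELECT 1 FROM multi_groups AS mg
                WHERE mg.group_id = g.group_id
                  AND mg.atom_id = vtm.atom_id
                  AND mg.sequence = vtm.sequence))
        ORDER BY g.group_id
    '''
    params = [v for e in entities for v in e]
    return [row[0] for row in conn.execute(sql, params)]


conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE multi_groups (group_id INTEGER, atom_id INTEGER, sequence INTEGER)')
conn.executemany('INSERT INTO multi_groups VALUES (?, ?, ?)', [
    (1, 1, 1), (1, 2, 2), (1, 3, 3),
    (2, 1, 3), (2, 2, 2), (2, 3, 1),
])

values_to_match = [Entity(1, 3), Entity(2, 2), Entity(3, 1)]
print(groups_containing(conn, values_to_match))  # [2]
```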

Alternatively, you could select the intersection of the groups that contain each of the given entities:

In [19]: gs = intersect(*[
    ...:     session.query(MultiColumnGroups.group_id).
    ...:         filter(MultiColumnGroups.atom_id == vtm.atom_id,
    ...:                MultiColumnGroups.sequence == vtm.sequence)
    ...:     for vtm in values_to_match
    ...: ])

In [20]: session.execute(gs).fetchall()
Out[20]: [(2,)]
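The intersection approach can also be tried without PostgreSQL, as plain SQL on SQLite via Python's `sqlite3` module rather than SQLAlchemy. Each entity contributes one `SELECT group_id ...`, and only the group_ids present in every result survive the `INTERSECT`. `groups_via_intersect` is a hypothetical helper; like relational division, this finds groups that contain the inputs, not groups that consist of exactly the inputs.

```python
import sqlite3
from collections import namedtuple

Entity = namedtuple('Entity', ['atom_id', 'sequence'])


def groups_via_intersect(conn, entities):
    """INTERSECT one group_id query per entity, like intersect(*[...]) above."""
    one = 'SELECT group_id FROM multi_groups WHERE atom_id = ? AND sequence = ?'
    # ORDER BY 1 sorts the compound result by its first (only) column.
    sql = '\nINTERSECT\n'.join(one for _ in entities) + '\nORDER BY 1'
    params = [v for e in entities for v in e]
    return [row[0] for row in conn.execute(sql, params)]


conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE multi_groups (group_id INTEGER, atom_id INTEGER, sequence INTEGER)')
conn.executemany('INSERT INTO multi_groups VALUES (?, ?, ?)', [
    (1, 1, 1), (1, 2, 2), (1, 3, 3),
    (2, 1, 3), (2, 2, 2), (2, 3, 1),
])

values_to_match = [Entity(1, 3), Entity(2, 2), Entity(3, 1)]
print(groups_via_intersect(conn, values_to_match))  # [2]
```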

The error

ProgrammingError: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]
LINE 3: ...gg((multi_groups.atom_id, multi_groups.sequence)) = ARRAY[AR...
                                                             ^
HINT:  No operator matches the given name and argument type(s). You might need to add explicit type casts.
 [SQL: 'SELECT multi_groups.group_id AS multi_groups_group_id \nFROM multi_groups GROUP BY multi_groups.group_id \nHAVING array_agg((multi_groups.atom_id, multi_groups.sequence)) = %(array_agg_1)s'] [parameters: {'array_agg_1': [[1, 3], [2, 2], [3, 1]]}] (Background on this error at: http://sqlalche.me/e/f405)

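is the result of your values_to_match first being converted to a list of lists (for reasons unknown) and then converted to an array by your DB-API driver. That produces an array of integer arrays, not an array of (int, int) records. Using a raw DB-API connection and cursor, passing a list of tuples works as expected.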
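Why a list of lists and a list of tuples end up as different SQL types can be seen with a toy renderer. This is an assumption-labelled model of the adaptation rules (Python lists become `ARRAY[...]`, tuples become `(...)` row constructors), not psycopg2's actual code, and its spacing differs from what the driver emits; `render` is a hypothetical name.

```python
from collections import namedtuple

Entity = namedtuple('Entity', ['atom_id', 'sequence'])


def render(value):
    """Toy model of driver-side parameter rendering (not psycopg2's real code).

    Lists become ARRAY[...] and tuples (including namedtuples) become (...).
    """
    if isinstance(value, list):
        return 'ARRAY[' + ', '.join(render(v) for v in value) + ']'
    if isinstance(value, tuple):
        return '(' + ', '.join(render(v) for v in value) + ')'
    return repr(value)


values_to_match = [Entity(1, 3), Entity(2, 2), Entity(3, 1)]

# A list of tuples: an array of records, comparable with array_agg's record[].
print(render(values_to_match))
# A list of lists, as in the logged parameters: a nested integer array.
print(render([list(e) for e in values_to_match]))
```

The second form matches the `ARRAY[AR...` fragment in the error message, which PostgreSQL types as an integer array rather than `record[]`.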

In SQLAlchemy it works if you wrap the list values_to_match in sqlalchemy.dialects.postgresql.array(), but keep in mind that the results are nondeterministic.

Answer 1 (score: 2)

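I found your answer very helpful as well. Since I don't have enough reputation to comment on your solution, I'm posting the changes I made based on your help.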

I found that the double negation generated somewhat less than ideal SQL, so I worked backwards from the SQL to find something cleaner.

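Here is some simple data. I modified the example slightly to use a text field, role, instead of the sequence field. This should generalize to other types: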

drop table if exists multi_groups;
create table multi_groups (group_id, atom_id, role) as
values
  (1, 1, 'referrer'),
  (1, 2, 'rendering'),
  (1, 3, 'attending'),
  (2, 1, 'attending'),
  (2, 2, 'rendering'),
  (2, 3, 'referrer');

The SQL generated by the original solution looks something like this:

select distinct
  dim_staging.multi_groups.group_id as dim_staging_multi_groups_group_id
from dim_staging.multi_groups
where not (
  exists (
    select *
    from (
           select
             unnest(
               array[1, 2, 3]
             ) as atom_id,
             unnest(
               array['referrer', 'rendering', 'attending']
             ) as role
    ) as anon_1
    where not (
      exists (
        select *
        from dim_staging.multi_groups as multi_groups_1
        where dim_staging.multi_groups.group_id = multi_groups_1.group_id
          and (anon_1.atom_id, anon_1.role) = (multi_groups_1.atom_id, multi_groups_1.role)
      )
    )
  )
);

I took that and worked on the SQL a bit:

with vtm as (
  select
    unnest(array[1, 2, 3]) as atom_id,
    unnest(array['attending', 'rendering', 'referrer']) as role
),
matched as (
  select
    dim_staging.multi_groups.group_id as group_id,
    vtm.atom_id as atom_id,
    3 as cnt
  from dim_staging.multi_groups
  full outer join vtm
    on (vtm.atom_id, vtm.role) = (dim_staging.multi_groups.atom_id, dim_staging.multi_groups.role)
)
select matched.group_id
from matched
where not (
  exists (
    select *
    from matched
    where matched.group_id is null
  )
)
group by matched.group_id
having count(1) filter (where matched.atom_id is null) = 0
  and count(1) = max(matched.cnt);

Here is a complete test script demonstrating how to generate the above SQL:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
import os
from sqlalchemy import (
    Column,
    Integer,
    Text
)
from sqlalchemy.sql.expression import func, select, tuple_, exists, join, literal, label
from sqlalchemy.dialects import postgresql
from collections import namedtuple


db_url = os.getenv('DB_URL', 'postgresql://localhost:5432/dw')
engine = create_engine(db_url, echo=False)
Session = sessionmaker(bind=engine)
session = Session()


Base = declarative_base()


class MultiColumnGroups(Base):
    __tablename__ = 'multi_groups'
    id = Column(Integer, primary_key=True)
    group_id = Column(Integer)
    atom_id = Column(Integer)
    role = Column(Text)


Base.metadata.drop_all(engine, [MultiColumnGroups.__table__])
Base.metadata.create_all(engine, [MultiColumnGroups.__table__])

vals = [
    # Group 1
    {'group_id': 1, 'atom_id': 1, 'role': 'referrer'},
    {'group_id': 1, 'atom_id': 2, 'role': 'rendering'},
    {'group_id': 1, 'atom_id': 3, 'role': 'attending'},
    # Group 2
    {'group_id': 2, 'atom_id': 1, 'role': 'attending'},
    {'group_id': 2, 'atom_id': 2, 'role': 'rendering'},
    {'group_id': 2, 'atom_id': 3, 'role': 'referrer'},
]

session.bulk_save_objects(
    [MultiColumnGroups(**x) for x in vals]
)
session.commit()

Entity = namedtuple('Entity', ['atom_id', 'role'])
values_to_match = [
    # (atom_id, role)
    # Entity(1, 'referrer'),
    # Entity(2, 'rendering'),
    # Entity(3, 'attending'),
    Entity(1, 'attending'),
    Entity(2, 'rendering'),
    Entity(3, 'referrer'),
]

vtm = select(
    [
        func.unnest(
            postgresql.array([
                getattr(e, f) for e in values_to_match
                ]
            )
        ).label(f)
        for f in Entity._fields
    ]
).cte(name='vtm')

j = join(
    MultiColumnGroups, vtm,
    tuple_(vtm.c.atom_id, vtm.c.role) == tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.role),
    full=True
)
matched = select([
    MultiColumnGroups.group_id,
    vtm.c.atom_id,
    label('cnt', literal(len(values_to_match), type_=Integer)),
]).select_from(j).cte(name='matched')

group_id = session.query(matched.c.group_id).\
    filter(
        ~exists().
        select_from(matched).
        where(matched.c.group_id == None)
    ).\
    group_by(matched.c.group_id).\
    having(func.count(1).filter(matched.c.atom_id == None) == 0).\
    having(func.count(1) == func.max(matched.c.cnt)).one().group_id

print(group_id)

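Edit: Handled the case where subgroups could cause multiple matches, by including the number of values being compared as a count in the query and checking that each matched group's count equals the number of values. Sorry for the oversight.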
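The requirement in that edit (the group has as many rows as there are inputs, and every one of its rows matches an input) can also be expressed without a full outer join. The sketch below is a swapped-in alternative, not the answer's query: it is plain SQL run on SQLite through Python's `sqlite3` so it works anywhere, it assumes there are no duplicate `(atom_id, role)` pairs within a group or within the input list, and `exact_groups` is a hypothetical helper. Group 3 in the data is a superset of group 2, which a containment-only query would also return.

```python
import sqlite3
from collections import namedtuple

Entity = namedtuple('Entity', ['atom_id', 'role'])


def exact_groups(conn, entities):
    """Groups whose (atom_id, role) rows are exactly the input entities."""
    placeholders = ', '.join('(?, ?)' for _ in entities)
    sql = f'''
        WITH vtm(atom_id, role) AS (VALUES {placeholders})
        SELECT g.group_id
        FROM multi_groups AS g
        LEFT JOIN vtm
          ON vtm.atom_id = g.atom_id AND vtm.role = g.role
        GROUP BY g.group_id
        HAVING count(*) = ?              -- group has as many rows as inputs
           AND count(vtm.atom_id) = ?    -- and every one of them matched an input
        ORDER BY g.group_id
    '''
    # VALUES placeholders come first in the SQL text, then the two HAVING ones.
    params = [v for e in entities for v in e] + [len(entities), len(entities)]
    return [row[0] for row in conn.execute(sql, params)]


conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE multi_groups (group_id INTEGER, atom_id INTEGER, role TEXT)')
conn.executemany('INSERT INTO multi_groups VALUES (?, ?, ?)', [
    (1, 1, 'referrer'), (1, 2, 'rendering'), (1, 3, 'attending'),
    (2, 1, 'attending'), (2, 2, 'rendering'), (2, 3, 'referrer'),
    # Group 3 contains all of group 2 plus one extra row.
    (3, 1, 'attending'), (3, 2, 'rendering'), (3, 3, 'referrer'), (3, 4, 'referrer'),
])

values_to_match = [Entity(1, 'attending'), Entity(2, 'rendering'), Entity(3, 'referrer')]
print(exact_groups(conn, values_to_match))  # [2]
```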