使用sqlalchemy orm过滤db列中的一系列值

时间:2014-12-11 20:18:55

标签: postgresql orm sqlalchemy subset

我有一个postgresql数据库,在一个特定的表中有很多行。该表中的一列称为数据,是一个浮点数组REAL [],并且填充了~4500个元素的数组。我想通过SQLAlchemy和ORM通过一些查询来访问此表。

如何选择表格中该列子集满足某些条件的所有行,例如是否包含一系列值?就像我想要选择数据包含值> = 10的所有行,或者> = 10和< = 20之间的值。

我可以使用像

这样的直接会话查询来完成此操作
rows = session.query(Table).filter(Table.data.(some conditional)).all()

我的条件类似于" VALUES> = 10和VALUES< = 20"?

或者,当我定义SQLAlchemy表类时,是否需要定义一些特殊方法或设置。例如,我将我的表设置为

class Table(Base):
    __tablename__ = 'table'
    __table_args__ = {'autoload' : True, 'schema' : 'testdb', 'extend_existing':True}

    data = deferred(Column(ARRAY(Float)))

    def __repr__(self):
        return '<Table (pk={0})>'.format(self.pk)       

理想情况下,我想设置它,以便我可以在session.query调用中进行简单的过滤。这可能吗?我对ORM并不是很熟悉,所以可能是这样吗?

我已经看过了ARRAY Comparator sqlalchemy文档,但那些似乎只是在精确值上工作。我的数据精确到6个sigfigs,我不会提前知道确切的值。

最好的方法是什么?谢谢。

编辑:

根据以下评论,以下是我尝试选择所有包含数据(来自1列)&gt; = 1.0的行(1000个中)的代码。应该有537行。

rows =  session.query(datadb.Table).filter(datadb.Table.data.any(1.0,operator=operators.le)).all()

这给出了正确的子集编号。 len(行)= 537.但是,我不理解使用此运算符的逻辑,在哪里选择数据&gt; = 1.0,我使用le运算符?此外,沿着这些相同的行,应该有234行,其值在&gt; = 1.0和&lt; 1.0之间,但是这个语句无法给出正确的子集..

rows = session.query(datadb.Table).filter(datadb.Table.data.any(1.0,operator=operators.le)).filter(datadb.Table.data.any(1.2,operator=operators.ge)).all()

*编辑2 *

这是我的数据库表的示例,其中包含几行。 pk是一个整数,数据是真实的[]。

db datadb
schema Table 
pk      data
0       [0.0,0.0,0.5,0.3,1.3,1.9,0.3,0.0,0.0]
1       [0.1,0.0,1.0,0.7,1.1,1.5,1.2,0.3,1.4]
2       [0.0,0.6,0.4,0.3,1.6,1.7,0.4,1.3,0.0]
3       [0.0,0.1,0.2,0.4,1.0,1.1,1.2,0.9,0.0]
4       [0.0,0.0,0.5,0.3,0.2,0.1,0.7,0.3,0.1]

我有5行,其中4行的数据值> = 1.0,而只有2行的值在&gt; = 1.0和&lt; = 1.2之间。我想要抓取行的查询是在第一种情况下

rows = session.query(datadb.Table).filter(datadb.Table.data.any(1.0,operator=operators.le)).all()

这应该返回4行,pk = 0,1,2,3。这个查询符合我的期望。第二种情况

rows = session.query(datadb.Table).filter(datadb.Table.data.any(1.0,operator=operators.le)).filter(datadb.Table.data.any(1.2,operator=operators.ge)).all()

并且应该在pk = 1,3处返回2行。但是,此查询只返回第一个查询中的4行。对于第二个查询,我也尝试了

rows = session.query(datadb.Table).filter(datadb.Table.data.any(1.0,operator=operators.le),datadb.Table.data.any(1.2,operator=operators.ge)).all()

也没有用。

1 个答案:

答案 0 :(得分:0)

请阅读ARRAY.Comparator上的文档,根据该文档,您应该能够执行以下操作:

rows = (session.query(Table)
        .filter(Table.data.any(10, operator=operators.le))
        .filter(Table.data.any(20, operator=operators.ge)
).all()

修改

# combined filter does not work,
# but applying one or the other is still useful as it reduces the result set
q = (session.query(MyTable)
     .filter(MyTable.data.any(1.0, operator=operators.le))
     # .filter(MyTable.data.any(1.2, operator=operators.ge))
     )

# filter in memory
items = [_row for _row in q.all()
         if any(1.0 <= item <= 1.2 for item in _row.data)]

for item in items:
    print(item)