使用`Table.where()`检索与条件匹配的PyTables表的行的索引

时间:2016-09-26 21:22:45

标签: numpy pytables

我需要匹配表中给定条件的行的索引(作为numpy数组)(具有数十亿行),这是我目前在我的代码中使用的行,它可以工作,但是非常难看:

indices = np.array([row.nrow for row in the_table.where("foo == 42")])

这也需要半分钟,我确信列表创建是其中一个原因。

我找不到一个优雅的解决方案,而且我还在努力研究pytables文档,那么有没有人知道任何神奇的方式来做得更漂亮,也许还要快一点?也许有一些特殊的查询关键字我缺少,因为我觉得pytables应该能够将匹配的行索引作为numpy数组返回。

2 个答案:

答案 0 :(得分:1)

tables.Table.get_where_list()给出与给定条件匹配的行的索引

答案 1 :(得分:0)

我读了pytables的来源,where()在Cython中实现,但似乎不够快。这是一个可以加速的复杂方法:

首先创建一些数据:

from tables import *
import numpy as np

class Particle(IsDescription):
    name      = StringCol(16)   # 16-character String
    idnumber  = Int64Col()      # Signed 64-bit integer
    ADCcount  = UInt16Col()     # Unsigned short integer
    TDCcount  = UInt8Col()      # unsigned byte
    grid_i    = Int32Col()      # 32-bit integer
    grid_j    = Int32Col()      # 32-bit integer
    pressure  = Float32Col()    # float  (single-precision)
    energy    = Float64Col()    # double (double-precision)
h5file = open_file("tutorial1.h5", mode = "w", title = "Test file")
group = h5file.create_group("/", 'detector', 'Detector information')
table = h5file.create_table(group, 'readout', Particle, "Readout example")
particle = table.row
for i in range(1001000):
    particle['name']  = 'Particle: %6d' % (i)
    particle['TDCcount'] = i % 256
    particle['ADCcount'] = (i * 256) % (1 << 16)
    particle['grid_i'] = i
    particle['grid_j'] = 10 - i
    particle['pressure'] = float(i*i)
    particle['energy'] = float(particle['pressure'] ** 4)
    particle['idnumber'] = i * (2 ** 34)
    # Insert a new particle record
    particle.append()

table.flush()
h5file.close()

以块的形式读取列并将索引附加到列表中,最后将列表连接到数组。您可以根据内存大小更改块大小:

h5file = open_file("tutorial1.h5")

table = h5file.get_node("/detector/readout")

size = 10000
col = "energy"
buf = np.zeros(batch, dtype=table.coldtypes[col])
res = []
for start in range(0, table.nrows, size):
    length = min(size, table.nrows - start)
    data = table.read(start, start + batch, field=col, out=buf[:length])
    tmp = np.where(data > 10000)[0]
    tmp += start
    res.append(tmp)
res = np.concatenate(res)