将numpy数组存储为PyTables单元元素

时间:2018-08-09 16:23:57

标签: python arrays numpy pytables

我有4个文件,其数据格式如下:其中3个文件包含不同维度的numpy数组,例如维度分别为20、30和25;每个文件中的记录数相同,例如都是10000条。第四个文件包含10000个浮点数(与每个文件中的数组个数相同)。我尝试基于这些文件创建一个具有以下结构的表:

+-----------------------------------------------------------+
| VecsFile #0   | VecsFile #1   | VecsFile #2   | FloatFile |
+-----------------------------------------------------------+
|np.ndarray(20,)|np.ndarray(30,)|np.ndarray(25,)|   0.1     |
+-----------------------------------------------------------+
|np.ndarray(20,)|np.ndarray(30,)|np.ndarray(25,)|   0.2     |
                               ...

我发现PyTables不接受numpy数组作为单元格数据的有效类型。

代码:

import tables
import numpy as np

def create_table_def(n_files, shapes=None):
    """Build a PyTables column description for the vectors-plus-float table.

    Creates one float64 column per vector file, named ``'VecsFile #<i>'``,
    plus a scalar float64 column ``'FloatFile'``.

    Args:
        n_files: Number of vector columns to create.
        shapes: Optional sequence of ``n_files`` per-row shapes, e.g.
            ``[(20,), (30,), (25,)]``. A PyTables column only accepts a
            numpy array per cell when it is declared with a matching
            ``shape``. When *shapes* is omitted every vector column is a
            scalar (the original behavior), and assigning an array to it
            raises ``TypeError: invalid type (<class 'numpy.ndarray'>)``.

    Returns:
        A dict mapping column names to ``tables.Col`` instances, suitable
        for passing to ``File.create_table``.

    Raises:
        ValueError: If *shapes* is given but its length is not *n_files*.
    """
    if shapes is None:
        # Shape () is a scalar column, identical to Col.from_atom(Float64Atom()).
        shapes = [()] * n_files
    elif len(shapes) != n_files:
        raise ValueError(
            'expected ' + str(n_files) + ' shapes, got ' + str(len(shapes)))

    table_def = {
        'VecsFile #' + str(rnum): tables.Float64Col(shape=shape)
        for rnum, shape in enumerate(shapes)
    }
    table_def['FloatFile'] = tables.Float64Col()
    return table_def

r0 = np.load('file0.npy')
r1 = np.load('file1.npy')
r2 = np.load('file2.npy')
# One scalar float per row. np.random.rand(*r0.shape) would build a matrix
# shaped like r0, making s[i] an array instead of the scalar that the
# 'FloatFile' column stores.
s = np.random.rand(r0.shape[0])

vec_arrays = (r0, r1, r2)
n_rows = r0.shape[0]
# zip() below would silently truncate on a length mismatch, so fail loudly.
if any(arr.shape[0] != n_rows for arr in vec_arrays):
    raise ValueError('all vector files must contain the same number of rows')


with tables.open_file('save.hdf', 'w') as saveFile:
    # Declare each vector column with its per-row shape (e.g. (20,)) so that
    # a whole numpy array can be stored in a single cell.
    table_def = create_table_def(
        len(vec_arrays), shapes=[arr.shape[1:] for arr in vec_arrays])
    table = saveFile.create_table(saveFile.root, 'que_vectors', table_def)
    tablerow = table.row
    for vec0, vec1, vec2, value in zip(r0, r1, r2, s):
        tablerow['VecsFile #0'] = vec0
        tablerow['VecsFile #1'] = vec1
        tablerow['VecsFile #2'] = vec2
        tablerow['FloatFile'] = value
        tablerow.append()
    table.flush()

我得到以下回溯:

    Traceback (most recent call last):
  File "C:/scratch_6.py", line 27, in <module>
    tablerow['VecsFile #0'] = r0[i]
  File "tables\tableextension.pyx", line 1591, in tables.tableextension.Row.__setitem__
TypeError: invalid type (<class 'numpy.ndarray'>) for column ``VecsFile #0``

是我哪里做错了吗?还是说,能否以这种方式把向量列和浮点数列一起存储在一个文件中,而无需先把所有向量拼接成一个numpy矩阵?我希望将来能够追加由若干向量和一个浮点数组成的新行,按范围读取这些行,并删除它们。

1 个答案:

答案 0 :(得分:1)

import numpy as np
import tables as tb


class NumpyTable(tb.IsDescription):
    """Table schema with a single column whose cells are 84 x 84 arrays.

    Each row's ``numpy_cell`` holds one float32 array of shape (84, 84).
    Declaring ``shape`` on the column is what lets a whole numpy array be
    stored in a single cell. A column with the default scalar shape
    rejects array values with ``TypeError: invalid type``.
    """
    # One float32 array of shape (84, 84) per row.
    numpy_cell = tb.Float32Col(shape=(84, 84))


# Key point for the question: give each column a fixed per-row ``shape``.
# A column declared this way stores one whole numpy array in every cell,
# e.g. ``tb.Float64Col(shape=(20,))`` for the 20-element vectors.

# Create a new file (mode 'w' replaces any existing one) holding a
# zlib-compressed table under /group. ``with`` guarantees the file is
# closed even if an error occurs part-way through.
filters = tb.Filters(complevel=5, complib='zlib')
with tb.open_file('numpy_cell.h5', mode='w') as fileh:
    group = fileh.create_group(fileh.root, 'group')
    np_table = fileh.create_table(group, 'numpy_table', NumpyTable, "group: NumpyTable",
                                  filters=filters)

    # ``table.row`` is a reusable row buffer (not the last stored row):
    # fill its fields, then call ``append()`` to queue it for writing.
    row = np_table.row

    # First row: an all-zeros cell.
    row['numpy_cell'] = np.zeros((84, 84), dtype=np.float32)
    row.append()

    # Second row: an all-ones cell.
    row['numpy_cell'] = np.ones((84, 84), dtype=np.float32)
    row.append()

    # Write the buffered rows to disk; the file closes when the block exits.
    np_table.flush()

# Re-open read-only and check that both cells round-tripped exactly.
# np.testing raises AssertionError like ``assert`` but is not stripped
# when Python runs with -O.
with tb.open_file('numpy_cell.h5', mode='r') as fileh:
    stored = fileh.root.group.numpy_table
    np.testing.assert_array_equal(
        stored[0]['numpy_cell'],
        np.zeros((84, 84), dtype=np.float32)
    )
    np.testing.assert_array_equal(
        stored[1]['numpy_cell'],
        np.ones((84, 84), dtype=np.float32)
    )