点乘不同dtype的大密集矩阵(float x boolean)

时间:2018-01-13 23:45:03

标签: python numpy matrix type-conversion sparse-matrix

我将2个矩阵乘以A.dot(B),其中:

A = 1 x n矩阵,dtype float

B = n x n矩阵,dtype布尔

我正在对大n执行此计算,并且内存耗尽非常快(大约n = 14000失败)。 A和B很密集。

看来原因是numpy在执行矩阵乘法之前将B转换为dtype float,因此会产生巨大的内存成本。事实上,%timeit表明它花费更多时间将B转换为浮动而不是执行乘法。

有没有办法解决这个问题?这里强调的是减少内存尖峰/浮点转换,同时仍然允许常见的矩阵功能(矩阵加法/乘法)。

以下是基准测试解决方案的可重现数据:

np.random.seed(999)
n = 30000
A = np.random.random(n)
B = np.where(np.random.random((n, n)) > 0.5, True, False)

1 个答案:

答案 0 :(得分:3)

您可以使用np.packbits然后在行上np.bincount来节省将布尔数组压缩到位域的空间和时间,以同时计算8个标量积的块。

import numpy as np

def setup_data(M, N):
    return {'B': np.random.randint(0, 2, (M, N), dtype=bool),
            'A': np.random.random((M,))}

def f_vecmat_mult(A, B, decode=np.array(np.unravel_index(np.arange(256), 8*(2,)))):
    M, N = B.shape
    out = [(decode * np.bincount(row, A, minlength=256)).sum(axis=1) for row in np.packbits(B, axis=1).T]
    if N & 7:
        out[-1] = out[-1][:N & 7]
    return np.concatenate(out)

def f_direct(A, B):
    return A @ B

import types
from timeit import timeit

for M, N in [(99, 80), (999, 777), (9999, 7777), (30000, 30000)]:
    data = setup_data(M, N)
    ref = f_vecmat_mult(**data)
    print(f'M, N = {M}, {N}')
    for name, func in list(globals().items()):
        if not name.startswith('f_') or not isinstance(func, types.FunctionType):
            continue
        try:
            assert np.allclose(ref, func(**data))
            print("{:16s}{:16.8f} ms".format(name[2:], timeit(
                'f(**data)', globals={'f':func, 'data':data}, number=100)*10))
        except:
            print("{:16s} apparently failed".format(name[2:]))

示例输出:

M, N = 99, 80
vecmat_mult           0.12248290 ms
direct                0.03647798 ms
M, N = 999, 777
vecmat_mult           1.67854790 ms
direct                5.68286091 ms
M, N = 9999, 7777
vecmat_mult          68.74523309 ms
direct              571.34140913 ms
M, N = 30000, 30000
vecmat_mult        1345.18991556 ms
direct           apparently failed