Creating a fast RGB lookup table in Python

Posted: 2018-09-24 20:25:20

Tags: python performance numpy opencv lookup-tables

I have a function, call it "rgb2something", that transforms RGB data [1x1x3] into a single value (a probability), and looping over every pixel of the input RGB data is very slow.
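For context, here is a minimal sketch of the kind of per-pixel loop I am trying to avoid (the body of rgb2something below is only a placeholder, not the real function):

import numpy as np

def rgb2something(r, g, b):
    # placeholder: the real function returns a single probability per pixel
    return (int(r) + int(g) + int(b)) / (3 * 255)

image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)

# the slow part: one Python-level call per pixel
# (at 1920x1080 that is over 2 million calls per image)
result = np.empty(image.shape[:2])
for i in range(image.shape[0]):
    for j in range(image.shape[1]):
        r, g, b = image[i, j]
        result[i, j] = rgb2something(r, g, b)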

I tried the following to speed up the conversion, generating a LUT (lookup table):

import numpy as np

levels = 256
levels2 = levels**2
lut = [0] * (levels ** 3)

levels_range = range(0, levels)

for r in levels_range:
    for g in levels_range:
        for b in levels_range:
            lut[r + (g * levels) + (b * levels2)] = rgb2something(r, g, b)

and converting the RGB image into the transformed probability image:

result = np.take(lut, r_channel + (g_channel * 256) + (b_channel * 65536))

But generating the LUT and computing the result are still slow. In 2 dimensions it is fairly fast, but in 3 dimensions (r, g and b) it is slow. How can I improve its performance?

EDIT

rgb2something(r, g, b) looks like this:

def rgb2something(r, g, b):
    y = np.array([[r, g, b]])
    y_mean = np.mean(y, axis=0)
    y_centered = y - y_mean
    y_cov = y_centered.T.dot(y_centered) / len(y_centered)
    m = len(Consts.x)
    n = len(y)
    q = m + n
    pool_cov = (m / q * x_cov) + (n / q * y_cov)
    inv_pool_cov = np.linalg.inv(pool_cov)
    g = Consts.x_mean - y_mean
    mah = g.T.dot(inv_pool_cov).dot(g) ** 0.5
    return mah

EDIT 2:

I am trying to get a complete working code example using OpenCV, so any OpenCV methods such as Apply LUT are welcome, as well as C/C++ approaches:

import matplotlib.pyplot as plt
import numpy as np 
import cv2

class Model:
    x = np.array([
        [6, 5, 2],
        [2, 5, 7],
        [6, 3, 1]
    ])
    x_mean = np.mean(x, axis=0)
    x_centered = x - x_mean
    x_covariance = x_centered.T.dot(x_centered) / len(x_centered)
    m = len(x)
    n = 1  # Only ever comparing to a single pixel
    q = m + n
    pooled_covariance = (m / q * x_covariance)  # + (n / q * y_cov) -< Always 0 for a single point
    inverse_pooled_covariance = np.linalg.inv(pooled_covariance)

def rgb2something(r, g, b):
    #Calculates Mahalanobis Distance between pixel and model X
    y = np.array([[r, g, b]])
    y_mean = np.mean(y, axis=0)
    g = Model.x_mean - y_mean
    mah = g.T.dot(Model.inverse_pooled_covariance).dot(g) ** 0.5
    return mah

def generate_lut():
    levels = 256
    levels2 = levels**2
    lut = [0] * (levels ** 3)

    levels_range = range(0, levels)

    for r in levels_range:
        for g in levels_range:
            for b in levels_range:
                lut[r + (g * levels) + (b * levels2)] = rgb2something(r, g, b)

    return lut

def calculate_distance(lut, input_image):
    return np.take(lut, input_image[:, :, 0] + (input_image[:, :, 1] * 256) + (input_image[:, :, 2] * 65536))

lut = generate_lut()
rgb = np.random.randint(255, size=(1080, 1920, 3), dtype=np.uint8)
result = calculate_distance(lut, rgb)

cv2.imshow("Example", rgb)
cv2.imshow("Result", result)
cv2.waitKey(0)

2 Answers:

Answer 0 (score: 3)

Update: added a blas optimization

There are a few straightforward and very effective optimizations:

(1) Vectorize, vectorize! It is not that difficult to vectorize everything in this code. See below.

(2) Use a proper lookup, i.e. fancy indexing, instead of np.take.

(3) Use a Cholesky decomposition of the inverse pooled covariance. With blas dtrmm we can take advantage of its triangular structure (a short check of why this works follows below).
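To spell out why (3) works: np.linalg.cholesky returns a lower-triangular factor L with L @ L.T equal to the inverse pooled covariance, so the Mahalanobis distance sqrt(g.T @ inv_cov @ g) is just the ordinary Euclidean norm of g @ L. A tiny check of that identity, with made-up numbers:

import numpy as np

inv_cov = np.array([[2.0, 0.3, 0.1],
                    [0.3, 1.5, 0.2],
                    [0.1, 0.2, 1.0]])   # any symmetric positive definite matrix
L = np.linalg.cholesky(inv_cov)         # L @ L.T == inv_cov
g = np.array([1.0, -2.0, 0.5])

mah_direct = np.sqrt(g @ inv_cov @ g)
mah_chol = np.sqrt(np.sum((g @ L) ** 2))
assert np.isclose(mah_direct, mah_chol)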

Here is the code. Just add it to the end of the OP's code (under EDIT 2). Unless you are very patient, you will probably also want to comment out the lut = generate_lut() and result = calculate_distance(lut, rgb) lines as well as all references to cv2. I also added a random row to x to make its covariance matrix non-singular.

class Full_Model(Model):
    ch = np.linalg.cholesky(Model.inverse_pooled_covariance)
    chx = Model.x_mean@ch

def rgb2something_vectorized(rgb):
    # Mahalanobis distance: apply the Cholesky factor, then take an ordinary Euclidean norm
    return np.sqrt(np.sum(((rgb - Full_Model.x_mean)@Full_Model.ch)**2,  axis=-1))

from scipy.linalg import blas

def rgb2something_blas(rgb):
    # same distance as above, but the triangular matrix multiply is delegated to blas dtrmm
    *shp, nchan = rgb.shape
    return np.sqrt(np.einsum('...i,...i', *2*(blas.dtrmm(1, Full_Model.ch.T, rgb.reshape(-1, nchan).T, 0, 0, 0, 0, 0).T - Full_Model.chx,))).reshape(shp)

def generate_lut_vectorized():
    return rgb2something_vectorized(np.transpose(np.indices((256, 256, 256))))

def generate_lut_blas():
    # fill a (256, 256, 256, 3) array with every possible colour triple, without meshgrid
    rng = np.arange(256)
    arr = np.empty((256, 256, 256, 3))
    arr[0, ..., 0]  = rng
    arr[0, ..., 1]  = rng[:, None]
    arr[1:, ...] = arr[0]
    arr[..., 2] = rng[:, None, None]
    return rgb2something_blas(arr)

def calculate_distance_vectorized(lut, input_image):
    return lut[input_image[..., 2], input_image[..., 1], input_image[..., 0]]

# test code

def random_check_lut(lut):
    """Because the original lut generator is excruciatingly slow,
    we only compare a random sample, using the original code
    """
    levels = 256
    levels2 = levels**2
    lut = lut.ravel()

    levels_range = range(0, levels)

    for r, g, b in np.random.randint(0, 256, (1000, 3)):
        assert np.isclose(lut[r + (g * levels) + (b * levels2)], rgb2something(r, g, b))

import time
td = []
td.append((time.time(), 'create lut vectorized'))
lutv = generate_lut_vectorized()
td.append((time.time(), 'create lut using blas'))
lutb = generate_lut_blas()
td.append((time.time(), 'lookup using np.take'))
res = calculate_distance(lutv, rgb)
td.append((time.time(), 'process on the fly (no lookup)'))
resotf = rgb2something_vectorized(rgb)
td.append((time.time(), 'process on the fly (blas)'))
resbla = rgb2something_blas(rgb)
td.append((time.time(), 'lookup using fancy indexing'))
resv = calculate_distance_vectorized(lutv, rgb)
td.append((time.time(), None))

print("sanity checks ... ", end='')
assert np.allclose(res, resotf) and np.allclose(res, resv) \
    and np.allclose(res, resbla) and np.allclose(lutv, lutb)
random_check_lut(lutv)
print('all ok\n')

t, d = zip(*td)
for ti, di in zip(np.diff(t), d):
    print(f'{di:32s} {ti:10.3f} seconds')

Sample run:

sanity checks ... all ok

create lut vectorized                 1.116 seconds
create lut using blas                 0.917 seconds
lookup using np.take                  0.398 seconds
process on the fly (no lookup)        0.127 seconds
process on the fly (blas)             0.069 seconds
lookup using fancy indexing           0.064 seconds

As we can see, the best lookup beats on-the-fly computation, but only by a whisker. That said, the example may overestimate the lookup cost, because random pixels are presumably less cache-friendly than natural images.

Original answer (perhaps still useful to some)

If rgb2something cannot be vectorized, and you want to process one typical image, then you can get a decent speedup using np.unique.

If rgb2something is expensive and multiple images have to be processed, then unique can be combined with caching, which is conveniently done with functools.lru_cache. The only (minor) stumbling block is that the arguments must be hashable; as it turns out, the modification this forces on the code (casting the rgb triples to 3-byte strings) happens to benefit performance.
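As an aside, the "3-byte strings" mentioned above come from viewing the contiguous uint8 pixel data with dtype 'S3', which yields one hashable bytes key per pixel (exactly what functools.lru_cache needs). A small illustration with made-up pixel values:

import numpy as np

pixels = np.array([[10, 20, 30],
                   [10, 20, 30],
                   [200, 40, 255]], dtype=np.uint8)

keys = pixels.ravel().view('S3')   # one 3-byte key per pixel
assert keys[0] == keys[1]          # identical pixels share a cache key
assert np.frombuffer(keys[2], np.uint8).tolist() == [200, 40, 255]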

Using a full lookup table only pays off if you have a huge number of pixels covering most hues. In that case the fastest way is to do the actual lookup with numpy fancy indexing.

import numpy as np
import time
import functools

def rgb2something(rgb):
    # waste some time:
    np.exp(0.1*rgb)
    return rgb.mean()

@functools.lru_cache(None)
def rgb2something_lru(rgb):
    rgb = np.frombuffer(rgb, np.uint8)
    # waste some time:
    np.exp(0.1*rgb)
    return rgb.mean()

def apply_to_img(img):
    shp = img.shape
    return np.reshape([rgb2something(x) for x in img.reshape(-1, shp[-1])], shp[:2])

def apply_to_img_lru(img):
    shp = img.shape
    return np.reshape([rgb2something_lru(x) for x in img.ravel().view('S3')], shp[:2])

def apply_to_img_smart(img, print_stats=True):
    shp = img.shape
    unq, bck = np.unique(img.reshape(-1, shp[-1]), return_inverse=True, axis=0)
    if print_stats:
        print('total no pixels', shp[0]*shp[1], '\nno unique pixels', len(unq))
    return np.array([rgb2something(x) for x in unq])[bck].reshape(shp[:2])

def apply_to_img_smarter(img, print_stats=True):
    shp = img.shape
    unq, bck = np.unique(img.ravel().view('S3'), return_inverse=True)
    if print_stats:
        print('total no pixels', shp[0]*shp[1], '\nno unique pixels', len(unq))
    return np.array([rgb2something_lru(x) for x in unq])[bck].reshape(shp[:2])

def make_full_lut():
    x = np.empty((3,), np.uint8)
    return np.reshape([rgb2something(x) for x[0] in range(256)
                       for x[1] in range(256) for x[2] in range(256)],
                      (256, 256, 256))

def make_full_lut_cheat(): # for quicker testing lookup
    i, j, k = np.ogrid[:256, :256, :256]
    return (i + j + k) / 3

def apply_to_img_full_lut(img, lut):
    return lut[(*np.moveaxis(img, 2, 0),)]

from scipy.misc import face

t0 = time.perf_counter()
bw = apply_to_img(face())
t1 = time.perf_counter()
print('naive                 ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_lru(face())
t1 = time.perf_counter()
print('lru first time        ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_lru(face())
t1 = time.perf_counter()
print('lru second time       ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_smart(face(), False)
t1 = time.perf_counter()
print('using unique:         ', t1-t0, 'seconds')

rgb2something_lru.cache_clear()

t0 = time.perf_counter()
bw = apply_to_img_smarter(face(), False)
t1 = time.perf_counter()
print('unique and lru first: ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_smarter(face(), False)
t1 = time.perf_counter()
print('unique and lru second:', t1-t0, 'seconds')

t0 = time.perf_counter()
lut = make_full_lut_cheat()
t1 = time.perf_counter()
print('creating full lut:    ', t1-t0, 'seconds')

t0 = time.perf_counter()
bw = apply_to_img_full_lut(face(), lut)
t1 = time.perf_counter()
print('using full lut:       ', t1-t0, 'seconds')

print()
apply_to_img_smart(face())

from PIL import Image
Image.fromarray(bw.astype(np.uint8)).save('bw.png')

Sample run:

naive                  6.8886632949870545 seconds
lru first time         1.7458112589956727 seconds
lru second time        0.4085628940083552 seconds
using unique:          2.0951434450107627 seconds
unique and lru first:  2.0168916099937633 seconds
unique and lru second: 0.3118703299842309 seconds
creating full lut:     151.17599205300212 seconds
using full lut:        0.12164952099556103 seconds

total no pixels 786432 
no unique pixels 134105

Answer 1 (score: 1)

First of all, please add the Consts used in the rgb2something function, as that will help us understand what exactly the function does.

The best way to speed this up is to vectorize the operation.

1) No caching

There is no need to construct a lookup table for this operation. If you have a function that is applied to every (r, g, b) vector, you can simply apply it to every vector in the image with np.apply_along_axis. In the example below I assume a simple placeholder definition of rgb2something; it can of course be replaced with your own definition.

def rgb2something(vector):
    return sum(vector)

image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
transform = np.apply_along_axis(rgb2something, -1, image)

This takes the image array and applies the function rgb2something to every 1-D slice along axis -1, which is the last (channel) axis.
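As a quick sanity check of the shapes involved (reusing the placeholder rgb2something from above), the transform contains one value per pixel:

import numpy as np

def rgb2something(vector):
    return sum(vector)  # placeholder, as above

image = np.random.randint(0, 256, size=(4, 5, 3), dtype=np.uint8)
transform = np.apply_along_axis(rgb2something, -1, image)
assert transform.shape == (4, 5)  # one output value per pixel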

2) Lazy lookup table

While caching is not needed, there are specific use cases where it would benefit you greatly. Perhaps you want to perform this pixel-wise rgb2something operation on thousands of images, and you suspect that many pixel values are repeated across them. In that case, constructing a lookup table can significantly improve performance. I suggest filling the table lazily (I assume the images in your dataset are somewhat similar, containing similar objects, textures, and so on, which implies that altogether they only cover a rather small subset of the entire 2^24 search space). If you believe they cover a relatively large subset, you can construct the whole lookup table beforehand (see the next section).

lut = [-1] * (256 ** 3)

def actual_rgb2something(vector):
    return sum(vector)

def rgb2something(vector):
    value = lut[vector[0] + vector[1] * 256 + vector[2] * 65536]

    if value == -1:
        value = actual_rgb2something(vector)
        lut[vector[0] + vector[1] * 256 + vector[2] * 65536] = value

    return value

Then you can transform each image just as before:

image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
transform = np.apply_along_axis(rgb2something, -1, image)

3) Precomputed cache

Perhaps your images are diverse enough to cover a large portion of the entire search space, so that the cost of building the whole cache is amortized by the reduced cost of every lookup.

from itertools import product

lut = [-1] * (256 ** 3)

def actual_rgb2something(vector):
    return sum(vector)

def fill(vector):
    value = actual_rgb2something(vector)
    lut[vector[0] + vector[1] * 256 + vector[2] * 65536] = value

# Fill the table
total = list(product(range(256), repeat=3))
np.apply_along_axis(fill, arr=total, axis=1)

Now, instead of computing the values, you simply look them up in the table:

def rgb2something(vector):
    return lut[vector[0] + vector[1] * 256 + vector[2] * 65536]

Transforming an image is of course the same as before:

image = np.random.randint(0, 256, size=(100, 100, 3), dtype=np.uint8)
transform = np.apply_along_axis(rgb2something, -1, image)