如何有效地比较2D矩阵中的每一对行?

时间:2019-05-01 19:25:05

标签: python numpy vectorization numpy-ndarray numpy-broadcasting

我正在处理一个子例程,在该例程中,我需要处理矩阵的每一行,并查找当前行中包含的其他行。为了说明某行何时包含另一行,请考虑以下3x3矩阵:

axios({
  method: "get",
  url: "https://stemuli.blob.core.windows.net/stemuli/mentor-lesson-video-c20665e2-b17c-4f22-9ef8-f6cd62b113dd.mp4",
  responseType: "stream"
}).then(function (response) {
  const videoType = response.headers["content-type"].split("/")[1];

  const file = fs.createWriteStream(
    `./cache/thumbnails/${tempFileName + "."}${videoType}`
  );
  response.data.pipe(file);

  thumbsupply
    .generateThumbnail(`./cache/thumbnails/${tempFileName}`)
    .then(thumb => res.json(thumb))
    .catch(err => {
      res.json({
        Error: "Error creating thumbnail for video"
      });
      console.log(err);
    });
});

第1行包含第3行,因为第1行中的每个元素都大于或等于第3行,但第1行不包含第2行。

我想出了以下解决方案,但是由于for循环(矩阵的大小约为6000x6000),它非常慢。

[[1, 0, 1], 

 [0, 1, 0], 

 [1, 0, 0]]

请问是否有可能更有效地做到这一点?

2 个答案:

答案 0 :(得分:1)

由于矩阵的大小以及问题的要求,我认为迭代是不可避免的。您不能利用广播,因为广播会消耗您的内存,因此您需要逐行对现有阵列进行操作。不过,您可以使用numbanjit通过纯python方法大大加快这一过程。


import numpy as np
from numba import njit


@njit
def zero_out_contained_rows(a):
    """
    Finds rows where all of the elements are
    equal or smaller than all corresponding
    elements of anothe row, and sets all
    values in the row to zero

    Parameters
    ----------
    a: ndarray
      The array to modify

    Returns
    -------
    The modified array

    Examples
    --------
    >>> zero_out_contained_rows(np.array([[1, 0, 1], [0, 1, 0], [1, 0, 0]]))
    array([[1, 0, 1],
            [0, 1, 0],
            [0, 0, 0]])
    """
    x, y = a.shape

    contained = np.zeros(x, dtype=np.bool_)

    for i in range(x):
        for j in range(x):
            if i != j and not contained[j]:
                equal = True
                for k in range(y):
                    if a[i, k] < a[j, k]:
                        equal = False
                        break
                contained[j] = equal

    a[contained] = 0

    return a

这将使您连续了解是否在另一行中使用了一行。这样可以在最终用0清除其他行中包含的行之前,通过短路来防止许多不必要的比较。


与您最初使用迭代进行的尝试相比,这是一个速度上的改进,并且还可以对适当的行进行归零。


a = np.random.randint(0, 2, (6000, 6000))

%timeit zero_out_contained_rows(a)
1.19 s ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

一旦您的尝试完成运行(目前大约10分钟),我就会更新计时。

答案 1 :(得分:0)

如果您的矩阵为6000x6000,则需要(6000 * 6000-6000)/ 2 = 17997000计算。

代替使用np.triu_indices,您可以尝试对矩阵的顶部三角形使用生成器-这样可以减少内存消耗。试试这个,也许会有帮助。

def indxs(lst):
   for i1, el1 in enumerate(lst):
      for el2 in lst[i1:][1:]:
         yield (el1, el2)