如何计算CSR矩阵对角线块中每行的非零

时间:2018-09-02 19:12:08

标签: c++ sparse-matrix

我有一个CSR格式的矩阵,需要一个c ++向量,该向量包含每行非零条目(计数)的数量,该数量限于大小不同的正方形,对角线块。

 // The matrix in CSR format
    std::vector<int> row_idx = {0,2,4,6,10,13}; // size n+1 where 0-n are the idx of the row starts in values and column_idx and n+1 the TOTAL number of values
    std::vector<int> values = {1,6,2,7,3,8,10,11,4,9,12,13,5}; // nonzero matrix values
    std::vector<int> column_idx = {0,3,1,3,2,4,0,1,3,4,2,3,4};  // column indices of the values

下面的示例有两个大小分别为A和B的块(感兴趣的块始终是正方形且在对角线上)。

此示例的预期结果将是nnz_in_ranges [n] = {1,1,2,2,3},但是由于它需要嵌入另一个例程中,因此我主要是在寻找一个例程来使用C ++。像这样:

// block A
int rangeStart = 0;
int rangeEnd = 2;

// block B
//int rangeStart = 2;
//int rangeEnd = n;

for (int i = rangeStart; i<rangeEND; ++i)
{
    nnz_in_ranges[n] = ...
}

// desired result for block A: nnz_in_ranges[n] = {1,1,0,0,0}
// desired result for block B: nnz_in_ranges[n] = {0,0,2,2,3}

我尝试使用std :: count ...函数解决它,但由于无法引入列范围,因此无法扩展下面的代码,计算每行的非零值。

有人知道如何解决此问题吗?

#include <stdio.h>
#include <iostream>
#include <vector>
#include <algorithm>

int main()
{

// NxN matrix example

/*
index     0   1    2   3   4
        ______________________
     0  | 1   0  | 0   6   0 |
        |   A    |           |
     1  | 0   2  | 0   7   0 |
        |--------------------|
     2  | 0   0  | 3   0   8 |
        |        |   B       |
     3  | 10  11 | 0   4   9 |  expected result: nnz_in_ranges[n] = {1,1,2,2,3} 
        |        |           |  here ranges are A and B
     4  | 0   0  |12  13   5 |
        ----------------------

*/


// matrix in CSR format

    int n = 5; // matrix size
    int nnz = 13; // number of nonzero values

    // The matrix in CSR format
    std::vector<int> row_idx = {0,2,4,6,10,13}; // size n+1 where 0-n are the idx of the row starts in values and column_idx and n+1 the TOTAL number of values
    std::vector<int> values = {1,6,2,7,3,8,10,11,4,9,12,13,5}; // nonzero matrix values
    std::vector<int> column_idx = {0,3,1,3,2,4,0,1,3,4,2,3,4};  // column indices of the values

    std::vector<int> tmp = {0,0,0,0,0,0,0,0,0,0,0,0,0};

    std::vector<int> sum(n);

    // count nonzeros per row  sum[] = {2,2,2,4,3}
    for(size_t i = 0; i < row_idx.size()-1; ++i) {
    sum[i] = std::count(tmp.begin() + row_idx[i], tmp.begin() + row_idx[i + 1], 0);
    }


    std::cout << "nnz_in_range = " << std::endl;
    for (int i=0; i<n; i++)
    {
    std::cout << ' ' << sum[i];
    }

return 0;

}

1 个答案:

答案 0 :(得分:0)

稀疏矩阵可以明确包含零个条目。根据您的问题,不清楚是要计算头寸还是实际值。我会假设是前者,因为您的计数代码未使用values

然后,我们只需遵循CSR格式的定义,并获得如下所示的内容:

std::vector<int> count_positions_in_block(int block_begin, int block_size, 
    const std::vector<int>& row_idx, const std::vector<int>& column_idx)
{
    std::vector<int> cnt(block_size, 0);

    const auto block_end = block_begin + block_size;
    assert(block_end < row_idx.size());

    for (auto row = block_begin; row < block_end; ++row)
    {
        auto first = row_idx[row];
        auto last = row_idx[row + 1];
        assert(first <= last);

        for (auto i = first; i < last; ++i)
            if (column_idx[i] >= block_begin && column_idx[i] < block_end)
                ++cnt[row - block_begin];
    }

    return cnt;
}

auto nnz1 = count_positions_in_block(0, 2, row_idx, column_idx);
// nnz1 = [1, 1]
auto nnz2 = count_positions_in_block(2, 3, row_idx, column_idx);
// nnz2 = [2, 2, 3]

可以使用std::count_if重写它:

std::vector<int> count_positions_in_block(int block_begin, int block_size, 
    const std::vector<int>& row_idx, const std::vector<int>& column_idx)
{
    std::vector<int> cnt(block_size, 0);

    const auto block_end = block_begin + block_size;
    assert(block_end < row_idx.size());

    const auto is_in_block = [&](auto col)
        { return (col >= block_begin && col < block_end); };

    for (auto row = block_begin; row < block_end; ++row)
    {
        auto first = row_idx[row];
        auto last = row_idx[row + 1];
        assert(first <= last);

        const auto cib = column_idx.begin();
        cnt[row - block_begin] = std::count_if(
            cib + first, cib + last, is_in_block);
    }

    return cnt;
}