C ++ - 有效地计算矢量矩阵乘积

时间:2016-02-24 14:21:40

标签: c++ matrix vector

我需要尽可能高效地计算产品矢量矩阵。具体来说,给定向量s和矩阵A,我需要计算s * A。我有一个Vector课程,其中包含std::vector和一个课程Matrix,它也包含std::vector(效率)。

天真的方法(我现在正在使用的方法)是有类似

的东西
Vector<T> timesMatrix(Matrix<T>& matrix)
{
    Vector<unsigned int> result(matrix.columns());
    // constructor that does a resize on the underlying std::vector

    for(unsigned int i = 0 ; i < vector.size() ; ++i)
    {
        for(unsigned int j = 0 ; j < matrix.columns() ; ++j)
        {
            result[j] += (vector[i] * matrix.getElementAt(i, j));
            // getElementAt accesses the appropriate entry
            // of the underlying std::vector
        }
    }
    return result;
}

工作正常,需要近12000微秒。请注意,向量s包含499个元素,而A499 x 15500

下一步是尝试并行化计算:如果我有N个线程,那么我可以给每个线程一个向量s的一部分和矩阵的{“对应”行{{1 }}。每个线程将计算一个499大小的A,最终结果将是它们的入口总和 首先,在类Vector中,我添加了一个方法来从Matrix中提取一些行并构建一个较小的行:

Matrix

然后我定义了一个线程例程

Matrix<T> extractSomeRows(unsigned int start, unsigned int end)
{
    unsigned int rowsToExtract = end - start + 1;
    std::vector<T> tmp;
    tmp.reserve(rowsToExtract * numColumns);
    for(unsigned int i = start * numColumns ; i < (end+1) * numColumns ; ++i)
    {
        tmp.push_back(matrix[i]);
    }
    return Matrix<T>(rowsToExtract, numColumns, tmp);
}

最后我修改了上面显示的void timesMatrixThreadRoutine (Matrix<T>& matrix, unsigned int start, unsigned int end, Vector<T>& newRow) { // newRow is supposed to contain the partial result // computed by a thread newRow.resize(matrix.columns()); for(unsigned int i = start ; i < end + 1 ; ++i) { for(unsigned int j = 0 ; j < matrix.columns() ; ++j) { newRow[j] += vector[i] * matrix.getElementAt(i - start, j); } } } 方法的代码:

timesMatrix

它仍然有效,但现在需要将近30000微秒,这几乎是以前的三倍。

我做错了吗?你认为有更好的方法吗?

编辑 - 使用“轻量级”Vector<T> timesMatrix(Matrix<T>& matrix) { static const unsigned int NUM_THREADS = 4; unsigned int matRows = matrix.rows(); unsigned int matColumns = matrix.columns(); unsigned int rowsEachThread = vector.size()/NUM_THREADS; std::thread threads[NUM_THREADS]; Vector<T> tmp[NUM_THREADS]; unsigned int start, end; // all but the last thread for(unsigned int i = 0 ; i < NUM_THREADS - 1 ; ++i) { start = i*rowsEachThread; end = (i+1)*rowsEachThread - 1; threads[i] = std::thread(&Vector<T>::timesMatrixThreadRoutine, this, matrix.extractSomeRows(start, end), start, end, std::ref(tmp[i])); } // last thread start = (NUM_THREADS-1)*rowsEachThread; end = matRows - 1; threads[NUM_THREADS - 1] = std::thread(&Vector<T>::timesMatrixThreadRoutine, this, matrix.extractSomeRows(start, end), start, end, std::ref(tmp[NUM_THREADS-1])); for(unsigned int i = 0 ; i < NUM_THREADS ; ++i) { threads[i].join(); } Vector<unsigned int> result(matColumns); for(unsigned int i = 0 ; i < NUM_THREADS ; ++i) { result = result + tmp[i]; // the operator+ is overloaded } return result; }

根据Ilya Ovodov的建议,我定义了一个包含VirtualMatrix的类VirtualMatrix,它在构造函数中初始化为

T* matrixData

然后有一种方法来检索矩阵的特定条目:

VirtualMatrix(Matrix<T>& m)
{
    numRows = m.rows();
    numColumns = m.columns();
    matrixData = m.pointerToData();
    // pointerToData() returns underlyingVector.data();
}

现在执行时间更好(大约8000微秒),但也许有一些改进。特别是线程例程现在是

inline T getElementAt(unsigned int row, unsigned int column)
{
    return *(matrixData + row*numColumns + column);
}

并且真正缓慢的部分是具有嵌套void timesMatrixThreadRoutine (VirtualMatrix<T>& matrix, unsigned int startRow, unsigned int endRow, Vector<T>& newRow) { unsigned int matColumns = matrix.columns(); newRow.resize(matColumns); for(unsigned int i = startRow ; i < endRow + 1 ; ++i) { for(unsigned int j = 0 ; j < matColumns ; ++j) { newRow[j] += (vector[i] * matrix.getElementAt(i, j)); } } } 循环的部分。如果我删除它,结果显然是错误的,但在不到500微秒内“计算”。这就是说现在传递参数几乎没有时间,重要的部分实际上是计算。

根据你的意见,有没有办法让它更快?

2 个答案:

答案 0 :(得分:2)

实际上,您为extractSomeRows中的每个线程制作了矩阵的部分副本。这需要很多时间。 重新设计它,使“某些行”成为虚拟矩阵,指向位于原始矩阵中的数据。

答案 1 :(得分:1)

使用矢量化汇编指令用于体系结构,使其更明确,您希望以4为乘,即x86-64 SSE2 +和可能ARM'S NEON。

如果你明确地在偶然元素中进行操作,C ++编译器通常可以将循环展开为向量化代码:

Simple and fast matrix-vector multiplication in C / C++

还可以选择使用专门用于矩阵乘法的库。对于较大的矩阵,使用基于快速傅立叶变换的特殊实现,Strassen算法等替代算法可能更有效。事实上,最好的办法是使用这样的C库,然后将其包装在看起来类似于C ++向量的接口。