我正在学习矩阵乘法,遇到了这段比较Strassen乘法与标准矩阵乘法的代码,于是尝试自己实现它。但是,这段代码占用的内存过多,当矩阵足够大时,程序会被系统杀死。另外,由于占用内存过多,处理时间也变得更长。
由于我对复杂的内存管理还不太熟悉,不敢对代码做太多改动,但我真的很想弄懂这个主题。
代码中设有一个cut参数(递归截断阈值);我发现将其设为320时运行更快,内存使用情况似乎也有所改善。
编辑:我实现了一个拷贝构造函数、一个析构函数和一个跟踪内存使用情况的函数,修复了之前的内存泄漏;但是对于Strassen乘法,维度从1900到2100之间耗时的大幅跳跃仍然存在。
matrix.h
#ifndef MATRIX_H
#define MATRIX_H
#include <vector>
using namespace std;
// Square matrix of ints that owns a heap-allocated, row-major buffer of
// dim_ * dim_ elements.
//
// The class manages a raw owning pointer, so it provides the full set of
// copy operations: copy constructor, copy assignment and destructor.
class matrix
{
public:
    // Creates a dim x dim matrix. When `strassen` is true the dimension is
    // rounded up to a power of two (at least 2). When `random` is true every
    // element is filled with rand() % 10.
    matrix(int dim, bool random, bool strassen);

    // Deep-copies old_m (allocates a new buffer and copies every element).
    matrix(const matrix& old_m);

    // Deep-copy assignment. Without it the compiler-generated version would
    // copy only the pointer, so both objects would later delete the same
    // buffer. The new buffer is allocated before the old one is released so
    // that *this is left untouched if the allocation throws.
    matrix& operator=(const matrix& other) {
        if (this != &other) {
            const int count = other.dim_ * other.dim_;
            int* fresh = new int[count];
            for (int i = 0; i < count; i++)
                fresh[i] = other.data_[i];
            delete [] data_;
            data_ = fresh;
            dim_ = other.dim_;
        }
        return *this;
    }

    // Number of rows (== number of columns). Const so it can be called
    // through const references.
    inline int dim() const {
        return dim_;
    }

    // Mutable element access; no bounds checking is performed.
    inline int& operator()(unsigned row, unsigned col) {
        return data_[dim_ * row + col];
    }

    // Read-only element access; no bounds checking is performed.
    inline int operator()(unsigned row, unsigned col) const {
        return data_[dim_ * row + col];
    }

    // Writes the matrix to standard output, one row per line.
    void print();

    // Element-wise sum / difference. `b` is taken by value.
    matrix operator+(matrix b);
    matrix operator-(matrix b);

    // Releases the element buffer.
    ~matrix();

private:
    int dim_;    // side length of the square matrix
    int* data_;  // owning pointer to dim_ * dim_ ints, row-major
};
#endif
Matrix.cpp
#include <iostream>
#include <vector>
#include <stdlib.h>
#include <time.h>
#include "SAMmatrix.h"
using namespace std;
// Constructs a square matrix.
//
// dim      requested side length; negative values are treated as 0.
// random   when true, every element is set to rand() % 10.
// strassen when true, the side length is rounded up to the next power of
//          two (minimum 2) so the matrix can be halved repeatedly.
//
// The buffer is always zero-initialised first, so a matrix built with
// random == false is a well-defined zero matrix and can be accumulated into
// with += without reading indeterminate values.
matrix::matrix(int dim, bool random, bool strassen) : dim_(dim) {
    if (strassen) {
        int padded = 2;
        while (padded < dim)
            padded *= 2;
        dim_ = padded;
    }
    if (dim_ < 0)
        dim_ = 0;  // a negative array size would make new[] fail

    // Compute the element count in size_t so dim_ * dim_ cannot overflow int.
    const size_t count = static_cast<size_t>(dim_) * static_cast<size_t>(dim_);
    data_ = new int[count]();  // trailing () value-initialises every element to 0
    if (!random) return;
    for (size_t i = 0; i < count; i++)
        data_[i] = rand() % 10;
}
// Copy constructor: allocates a fresh buffer with the same side length as
// old_m and copies every element across, so the two matrices never share
// storage.
matrix::matrix(const matrix& old_m)
{
    dim_ = old_m.dim_;
    const int count = dim_ * dim_;
    data_ = new int[count];

    const int* src = old_m.data_;
    int* dst = data_;
    for (int k = 0; k < count; ++k)
        dst[k] = src[k];
}
// Prints the matrix to standard output: each element is followed by a single
// space, each row ends with a newline, and one extra newline follows the
// last row.
void matrix::print() {
    for (int row = 0; row < dim_; ++row) {
        for (int col = 0; col < dim_; ++col) {
            const int value = (*this)(row, col);
            std::cout << value << " ";
        }
        std::cout << "\n";
    }
    std::cout << "\n";
}
// Element-wise sum: returns a new matrix with the same side length as *this
// whose (row, col) entry is this(row, col) + b(row, col).
matrix matrix::operator+(matrix b) {
    matrix sum(dim_, false, false);
    for (int row = 0; row < dim_; ++row) {
        for (int col = 0; col < dim_; ++col) {
            const int lhs = (*this)(row, col);
            const int rhs = b(row, col);
            sum(row, col) = lhs + rhs;
        }
    }
    return sum;
}
// Element-wise difference: returns a new matrix with the same side length as
// *this whose (row, col) entry is this(row, col) - b(row, col).
matrix matrix::operator-(matrix b) {
    matrix diff(dim_, false, false);
    for (int row = 0; row < dim_; ++row) {
        for (int col = 0; col < dim_; ++col) {
            const int lhs = (*this)(row, col);
            const int rhs = b(row, col);
            diff(row, col) = lhs - rhs;
        }
    }
    return diff;
}
// Destructor: frees the element buffer that was allocated with new[].
matrix::~matrix() { delete[] data_; }
矩阵主体
#include <iostream>
#include <stdlib.h>
#include <time.h>
#include <sys/time.h>
#include "SAMmatrix.h"
#include "stdlib.h"
#include "stdio.h"
#include "string.h"
typedef pair<matrix, long> result;
int cut = 64;
// Standard O(n^3) matrix product c = a * b, using the side length of `a`.
//
// The loops run in i-k-j order so the innermost loop walks one row of `b`
// and one row of `c` contiguously.
//
// Each output row is explicitly zeroed before accumulation: the result is
// built with `+=`, so it must not depend on whatever the freshly constructed
// matrix happens to contain.
matrix mult_std(matrix a, matrix b)
{
    const int n = a.dim();
    matrix c(n, false, false);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++)
            c(i, j) = 0;
        for (int k = 0; k < n; k++) {
            const int aik = a(i, k);  // invariant across the inner loop
            for (int j = 0; j < n; j++)
                c(i, j) += aik * b(k, j);
        }
    }
    return c;
}
// Extracts one quadrant of `m` as a new matrix of side length m.dim() / 2.
//
// pi, pj  quadrant selector: 0 or 1 for the upper/lower and left/right half.
// m       source matrix; it is passed by value, so every call copies it.
//
// The quadrant is created with exactly half the source dimension. Rounding
// that size up to a power of two would make the copy loop read past the end
// of `m` whenever m.dim() / 2 is not itself a power of two (for example
// m.dim() == 2 or m.dim() == 6). For an odd m.dim() the last row and column
// of `m` are not covered by any quadrant.
matrix get_part(int pi, int pj, matrix m)
{
    matrix p(m.dim() / 2, false, false);
    const int half = p.dim();
    const int row0 = pi * half;
    const int col0 = pj * half;
    for (int i = 0; i < half; i++)
        for (int j = 0; j < half; j++)
            p(i, j) = m(i + row0, j + col0);
    return p;
}
void set_part(int pi, int pj, matrix* m, matrix p)
{
pi = pi * p.dim();
pj = pj * p.dim();
for (int i = 0; i < p.dim(); i++)
for (int j = 0; j < p.dim(); j++)
(*m)(i + pi, j + pj) = p(i, j);
}
// Strassen matrix product c = a * b.
//
// Operands whose dimension is at most `cut` are multiplied with mult_std.
// Larger operands are split into four quadrants each, seven recursive
// products m1..m7 are formed, and the four quadrants of the result are
// assembled from sums and differences of those products.
//
// The split is only meaningful when the dimension is a power of two greater
// than 2: the result matrix is constructed with its size rounded up to a
// power of two, and halving must yield quadrants that tile the operands
// exactly. Any other dimension is handed to mult_std as well, which also
// guarantees the recursion terminates when `cut` is smaller than 2.
matrix mult_strassen(matrix a, matrix b)
{
    const int n = a.dim();
    const bool splittable = n > 2 && (n & (n - 1)) == 0;  // power of two, > 2
    if (n <= cut || !splittable)
        return mult_std(a, b);

    // Quadrants of both operands.
    matrix a11 = get_part(0, 0, a);
    matrix a12 = get_part(0, 1, a);
    matrix a21 = get_part(1, 0, a);
    matrix a22 = get_part(1, 1, a);
    matrix b11 = get_part(0, 0, b);
    matrix b12 = get_part(0, 1, b);
    matrix b21 = get_part(1, 0, b);
    matrix b22 = get_part(1, 1, b);

    // The seven Strassen products.
    matrix m1 = mult_strassen(a11 + a22, b11 + b22);
    matrix m2 = mult_strassen(a21 + a22, b11);
    matrix m3 = mult_strassen(a11, b12 - b22);
    matrix m4 = mult_strassen(a22, b21 - b11);
    matrix m5 = mult_strassen(a11 + a12, b22);
    matrix m6 = mult_strassen(a21 - a11, b11 + b12);
    matrix m7 = mult_strassen(a12 - a22, b21 + b22);

    // Assemble the result quadrant by quadrant; every element is overwritten.
    matrix c(n, false, true);
    set_part(0, 0, &c, m1 + m4 - m5 + m7);
    set_part(0, 1, &c, m3 + m5);
    set_part(1, 0, &c, m2 + m4);
    set_part(1, 1, &c, m1 - m2 + m3 + m6);
    return c;
}
// Times one call of the multiplication routine `f` on (a, b) with
// gettimeofday and returns the product together with the elapsed wall-clock
// time in milliseconds. Both timestamps are converted to whole milliseconds
// before being subtracted.
pair<matrix, long> run(matrix(*f)(matrix, matrix), matrix a, matrix b)
{
    struct timeval t_begin;
    struct timeval t_end;

    gettimeofday(&t_begin, NULL);
    matrix product = f(a, b);
    gettimeofday(&t_end, NULL);

    const long end_ms = (t_end.tv_sec * 1000 + t_end.tv_usec / 1000);
    const long begin_ms = (t_begin.tv_sec * 1000 + t_begin.tv_usec / 1000);
    const long elapsed_ms = end_ms - begin_ms;

    return pair<matrix, long>(product, elapsed_ms);
}
// Extracts the first decimal number found in `line`, e.g. the value of a
// "VmSize:   123456 kB" line.
//
// Returns the parsed number, or -1 when `line` is NULL or contains no digit.
//
// The scan stops at the terminating NUL, so a line without digits can no
// longer run off the end of the buffer. The buffer itself is not modified:
// atoi stops at the first non-digit character, so cutting off the trailing
// unit is unnecessary, and doing so wrote before the start of the buffer for
// lines shorter than three characters.
int parseLine(char* line){
    if (line == NULL)
        return -1;
    const char* p = line;
    while (*p != '\0' && (*p < '0' || *p > '9'))
        p++;
    if (*p == '\0')
        return -1;  // no digit anywhere in the line
    return atoi(p);
}
// Reads the "VmSize:" entry from /proc/self/status and returns its numeric
// value (the file reports it in kB).
//
// Returns -1 when the file cannot be opened (for example on systems without
// /proc) or when it contains no "VmSize:" line.
int getValue(){ //Note: this value is in KB!
    FILE* file = fopen("/proc/self/status", "r");
    if (file == NULL)
        return -1;  // fgets/fclose must not be called with a NULL stream

    int result = -1;
    char line[128];
    while (fgets(line, sizeof(line), file) != NULL){
        if (strncmp(line, "VmSize:", 7) == 0){
            result = parseLine(line);
            break;
        }
    }
    fclose(file);
    return result;
}
int main()
{
/* test cut of for strassen
/*
for (cut = 2; cut <= 512; cut++) {
matrix a(512, true, true);
matrix b(512, true, true);
result r = run(mult_strassen, a, b);
cout << cut << " " << r.second << "\n";
}
*/
/* performance test: standard and strassen */
/*1024 going up by 64*/
for (int dim = 1500; dim <= 2300; dim += 200)
{
double space = getValue() * .01;
cout << "Space before: " << space << "Mb" << "\n";
matrix a(dim, true, false);
matrix b(dim, true, false);
result std = run(mult_std, a, b);
matrix c(dim, true, true);
matrix d(dim, true, true);
result strassen = run(mult_strassen, c, d);
cout << "Dim " << " Std " << " Stranssen" << endl;
cout << dim << " " << std.second << "ms " << strassen.second << "ms " << "\n";
double spaceA = getValue() * .01;
cout << "Space: " << spaceA << "Mb" << "\n";
cout << " " << endl;
}
}
我将维度设置为从1500递增到2300(步长为200),程序在运行结束之前就被“杀死”(Killed)了:
1500 2406 4250
1700 3463 4252
1900 4819 4247
2100 6487 30023
Killed
此外,当维度从1900增加到2100时,耗时也不应该出现如此大幅度的跳跃。