运算符+类中的错误Matrix修改的稀疏行压缩

时间:2017-11-18 12:30:43

标签: c++ matrix sparse-matrix

我实现了一个修改后的压缩稀疏行矩阵类,operator+(matrix, matrix)的问题在所有情况下都不起作用,但我没有发现错误!这里是矩阵类接口:

template <typename data_type>
class MCSRmatrix {
public: 
template <typename T>
          friend std::ostream& operator<<(std::ostream& os ,const MCSRmatrix<T>& m) noexcept ;


template <typename T>

 friend MCSRmatrix<T> operator+(const MCSRmatrix<T>& m1, const MCSRmatrix<T>& m2 ) ;


 constexpr MCSRmatrix(std::initializer_list<std::initializer_list<data_type>> rows);
            auto constexpr printMCSR() const noexcept ;

std::size_t constexpr findIndex(const itype row ,  const itype col) const noexcept ;
    const data_type operator()(const itype r , const itype c) const noexcept ;

          data_type operator()(const itype r , const itype c) noexcept ;


        private:

             std::vector<data_type> aa_ ;
             std::vector<itype>     ja_ ; 

             int dim ; 
        };

并且运算符+写为:

template <typename T>
MCSRmatrix<T> operator+(const MCSRmatrix<T>& m1, const MCSRmatrix<T>& m2 )
{
      if(m1.dim != m2.dim)
      {
          throw InvalidSizeExceptm1.ja_.at(m1.ja_.at(i)-1 + j1 )ion("Matrixs dimension does match! Error in operator +");  
      }
      else
      {
         MCSRmatrix<T> res(m1.dim);



         for(auto i=0; i < res.dim ; i++)
            res.aa_.at(i) = m1.aa_.at(i)  + m2.aa_.at(i) ; 
         res.ja_.at(0) = res.dim+2;

         std::set<unsigned int> ctrl; 
         std::set<unsigned int> index ;   

         int n1=0, n2=0, j1=0 , j2 =0;
         for(auto i=0 , numElemRow=0 ; i < res.dim  ; i++)
         {
           index.clear(); 
           ctrl.clear();

           n1 = m1.ja_.at(i+1)- m1.ja_.at(i) ;
           n2 = m2.ja_.at(i+1)- m2.ja_.at(i) ;

           j1=0 , j2=0 ;
           auto sum1 = 0. , sum2 = 0. , sum=0.;


            for(auto j = 0; j < std::max(n1,n2) ; j++ )
            {
               if(n1 && n2) 
               {
                 if(m1.ja_.at(m1.ja_.at(i)-1 + j1 ) == m2.ja_.at(m2.ja_.at(i)-1 + j2)) 
                 {
                     ctrl.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));

                     index.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));

                     sum = m1.aa_.at(m1.ja_.at(i)-1 + j1 ) + m2.aa_.at(m2.ja_.at(i)-1 + j2) ;


                 }
                 else if(m1.ja_.at(m1.dim+1 + j1 ) != m2.ja_.at(m2.dim+1 + j2))
                 {   
                     ctrl.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
                     ctrl.insert(m2.ja_.at(m1.ja_.at(i)-1 + j2 ));

                     index.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
                     index.insert(m2.ja_.at(m2.ja_.at(i)-1 + j2 ));

                     sum1 = m1.aa_.at(m1.ja_.at(i)-1 + j1);               
                     sum2 = m2.aa_.at(m2.ja_.at(i)-1 + j1);     
                 }
               }        
               else if(n1)
               {  
                  ctrl.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
                  index.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
                  sum1 = m1.aa_.at(m1.ja_.at(i)-1 + j1);               
               }
               else if(n2)
               {
                  ctrl.insert(m2.ja_.at(m2.ja_.at(i)-1 + j1 ));
                  index.insert(m2.ja_.at(m2.ja_.at(i)-1 + j1 ));
                  sum2 = m2.aa_.at(m2.ja_.at(i)-1 + j1);     
               }
           //    std::cout << "sum   " << sum  << '\n';    
           //    std::cout << "sum1  " << sum1 << '\n';    
           //    std::cout << "sum2  " << sum2 << '\n';    
               if(sum1)
                  res.aa_.push_back(sum1);
               else if(sum2)
                  res.aa_.push_back(sum2);
               else if(sum)
                  res.aa_.push_back(sum);

               if(j1 < n1) j1++ ;
               if(j2 < n2) j2++ ;
             } 


             //std::cout << res.ja_.size() << '\n'; 
            res.ja_.at(i+1) = res.ja_.at(i) + ctrl.size() ;  
            for(auto& x : index )
               res.ja_.push_back(x);


         }
         return res ;   
      }
}

构造函数是:

template <typename T>
constexpr MCSRmatrix<T>::MCSRmatrix( std::initializer_list<std::initializer_list<T>> rows)
{
      this->dim  = rows.size();
      auto _rows = *(rows.begin());

      aa_.resize(dim+1);
      ja_.resize(dim+1);

      if(dim != _rows.size())
      {
          throw InvalidSizeException("Error in costructor! MCSR format require square matrix!");  
      }

      itype w = 0 ;
      ja_.at(w) = dim+2 ;
      for(auto ii = rows.begin(), i=1; ii != rows.end() ; ++ii, i++)
      {
          for(auto ij = ii->begin(), j=1, elemCount = 0 ; ij != ii->end() ; ++ij, j++ )   
          {
              if(i==j)
                 aa_[i-1] = *ij ;
              else if( i != j && *ij != 0 )
              {   
                 ja_.push_back(j); 
                 aa_.push_back(*ij); 
                 elemCount++ ;
              }
              ja_[i] = ja_[i-1] + elemCount;           
          }
      }     
      printMCSR();
}

主程序和整个班级在这里链接ModCSRmatrix.Hmain.cpp你可以在主程序的最后一个案例中看到结果是错误的

好的我认为我找到并纠正了这个错误!! (仍在寻找未来的一个) 运算符+被重写如下:

template <typename T>
MCSRmatrix<T> operator+(const MCSRmatrix<T>& m1, const MCSRmatrix<T>& m2 )
{
      if(m1.dim != m2.dim)
      {
          throw InvalidSizeException("Matrixs dimension does match! Error in operator +");  
      }
      else
      {
         MCSRmatrix<T> res(m1.dim);



         for(auto i=0; i < res.dim ; i++)
            res.aa_.at(i) = m1.aa_.at(i)  + m2.aa_.at(i) ; 
         res.ja_.at(0) = res.dim+2;

         std::set<unsigned int> ctrl; 
         std::set<unsigned int> index ;   

         int n1=0, n2=0, j1=0 , j2 =0;
         for(auto i=0 , numElemRow=0 ; i < res.dim  ; i++)
         {
           index.clear(); 
           ctrl.clear();

           n1 = m1.ja_.at(i+1)- m1.ja_.at(i) ;
           n2 = m2.ja_.at(i+1)- m2.ja_.at(i) ;

           j1=0 , j2=0 ;
           auto sum1 = 0. , sum2 = 0. , sum=0.;


            for(auto j = 0; j < std::max(n1,n2) ; j++ )
            {
               if(n1 && n2) 
               {
                 if(m1.ja_.at(m1.ja_.at(i)-1 + j1 ) == m2.ja_.at(m2.ja_.at(i)-1 + j2)) 
                 {
                     ctrl.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));

                     index.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));

                     sum = m1.aa_.at(m1.ja_.at(i)-1 + j1 ) + m2.aa_.at(m2.ja_.at(i)-1 + j2) ;


                 }
                 else if(m1.ja_.at(m1.dim+1 + j1 ) != m2.ja_.at(m2.dim+1 + j2))
                 {   
                     ctrl.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
                     ctrl.insert(m2.ja_.at(m1.ja_.at(i)-1 + j2 ));

                     index.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
                     index.insert(m2.ja_.at(m2.ja_.at(i)-1 + j2 ));

                     sum1 = m1.aa_.at(m1.ja_.at(i)-1 + j1);               
                     sum2 = m2.aa_.at(m2.ja_.at(i)-1 + j1);     
                 }
               }        
               else if(n1)
               {  
                  ctrl.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
                  index.insert(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
                  sum1 = m1.aa_.at(m1.ja_.at(i)-1 + j1);               
               }
               else if(n2)
               {
                  ctrl.insert(m2.ja_.at(m2.ja_.at(i)-1 + j1 ));
                  index.insert(m2.ja_.at(m2.ja_.at(i)-1 + j1 ));
                  sum2 = m2.aa_.at(m2.ja_.at(i)-1 + j1);     
               }
               if(sum1)
               {
                 // std::cout << "sum1  " << sum1 << "indx: " << m1.ja_.at(m1.ja_.at(i)-1 + j1 ) << '\n';    
                  res.aa_.push_back(sum1);
                  res.ja_.push_back(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
               }
               if(sum2)
               {
                 // std::cout << "sum2  " << sum2 << "indx: " << m2.ja_.at(m2.ja_.at(i)-1 + j1 ) << '\n';    
                  res.aa_.push_back(sum2);
                  res.ja_.push_back(m2.ja_.at(m2.ja_.at(i)-1 + j2 ));
               }
               if(sum)
               {    
                  //  std::cout << "sum   " << sum  << "indx: " << m1.ja_.at(m1.ja_.at(i)-1 + j1 ) <<'\n';    
                    res.aa_.push_back(sum);
                    res.ja_.push_back(m1.ja_.at(m1.ja_.at(i)-1 + j1 ));
               }

               if(j1 < n1) j1++ ;
               if(j2 < n2) j2++ ;
             } 


             //std::cout << res.ja_.size() << '\n'; 
            res.ja_.at(i+1) = res.ja_.at(i) + ctrl.size() ;  
            //for(auto& x : index )     <-- here was the problem 
            //   res.ja_.push_back(x);  <<-- the index must push_back at
             // the same time of the above push_back into the vector of
             // value (aa_)



         }
         return res ;   
      }
}

如果你想让我知道的话,如果这不是问题,我会把这个帖子打开另一天。

0 个答案:

没有答案