具有已知矩阵的模板矩阵表达式

时间:2012-04-04 03:20:08

标签: c++ templates matrix

这是一个思考练习,不是特别的问题,但我想听听你的意见。假设我有一些使用模板的矩阵表达式DSL(Eigen,ublas等)。

现在假设我有一些常量矩阵,例如:

Matrix2 sigma1 = {{0,1}, {1,0}};
Matrix2 sigma2 = {{0,i}, {-i,0}};
... etc

...我对那些涉及运行时值的矩阵有一些操作:

a*sigma1 + b*sigma2; // a,b runtime

您有什么想法来实现常量矩阵,以便编译器可以最大程度地优化表达式?特别是,如何将(i,j)运算符解析为常量?

2 个答案:

答案 0 :(得分:4)

根据我对问题空间的理解:给定domain-specific language,我们想要确定最小变换,使得某些运算符超过数据(例如,(i,j)),导致查找常数而不是数学公式的计算(例如,a*sigma1 + b*sigma2)。

让我们来探讨一些可能性:

  • 直接执行数学运算

    0th-level实施。如果您的编译器没有明确的优化,如果我们直接执行指令会发生什么?

    答案取决于它。但是,在大多数处理器上,您的代码执行将落在the CPU cache上,您的程序集和分支执行将根据您内核的最佳能力进行优化。该过程的内容是actually really cool,但我们假设您希望超越这些功能并直接在代码中解决操作的限制。

  • 使用compiler-compiler

    绑定并捕获空间

    一阶优化是使用compiler-compiler绑定并捕获可能的输入和输出空间。虽然这将有效地解决您的输入范围,仅将其限制为您想要的输入和输出集合,否则它 。所以,我们必须继续前进。

  • StringificationMacro Expansion

    二阶优化将直接执行值空间的字符串或宏扩展。虽然这是fraught with corner-cases and surprising implementation-level tar pits,但如果需要操作,可以由编译器直接完成。 (另见:loop unwinding

  • closed-form expression的手动推导和堆栈约束可满足性
    (使用例如查找表)

    最后,我们的三阶优化将直接限制您的空间。这要求您有一个明确定义的closed-form relation,其有限的输入和输出空间可以有效地工作。如果无法确定此关系或没有界限,you're out of luck并且如果不知道存在更好的关系,则应考虑保留当前的实现。

在这些优化技术中,考虑到您所描述的问题边界,最适用于linear algebraic operations的是后两种。由于大多数操作(例如矩阵平移,旋转和缩放操作)本质上都是确定性的,因此您确实可以有效地优化和限制空间。

要获得更理论的答案,建议您咨询http://cs.stackexchange.comhttp://cstheory.stackexchange.comhttp://math.stackexchange.com。两者都有许多线程专门用于decidability和封闭形式的证明,polynomial solutions用于整个方程组。

答案 1 :(得分:2)

好的,这很可怕,但与我对原帖的评论有关。

使用这种结构应该可以定义您需要的相关操作,但编写所有适当的特殊化将需要做很多工作。您可能还希望交换行/列。

最后定义矩阵当然不像原帖那样优雅,但也许可以改进,特别是在C ++ 11中使用'auto'。

//-----------------------------------------------------------------------------
struct Unused {};

struct Imaginary {
    Imaginary() {}
    Imaginary(Unused const& unused) {}
};
struct MinusImaginary {
    MinusImaginary() {}
    MinusImaginary(Unused const& unused) {}
};

//-----------------------------------------------------------------------------
template <int I, int F = 0>
struct Fixed {
    Fixed() {}
    Fixed(Unused const& unused) {}
};

//-----------------------------------------------------------------------------
struct Float
{
    Float(float value) : value_(value) {}
    const float value_;
};

//-----------------------------------------------------------------------------
template <typename COL0, typename COL1>
struct Vector2
{
    typedef COL0 col0_t;
    typedef COL1 col1_t;

    template <typename T0, typename T1>
    Vector2(T0 const& t0, T1 const& t1)
        : col0_(t0)
        , col1_(t1)
    {}

    COL0 col0_;
    COL1 col1_;
};

//-----------------------------------------------------------------------------
template <typename ROW0, typename ROW1>
struct Matrix2
{
    typedef ROW0 row0_t;
    typedef ROW1 row1_t;

    Matrix2()
        : row0_(Unused(), Unused())
        , row1_(Unused(), Unused())
    {
    }
    template <typename M00, typename M01, typename M10, typename M11>
    Matrix2(M00 const& m00, M01 const& m01, M10 const& m10, M11 const& m11)
        : row0_(m00, m01)
        , row1_(m10, m11)
    {
    }

    ROW0 row0_;
    ROW1 row1_;
};

//-----------------------------------------------------------------------------
Matrix2<
    Vector2< Fixed<0>, Fixed<1> >,
    Vector2< Fixed<1>, Fixed<0> > 
> sigma1;

const float f = 0.1f;

//-----------------------------------------------------------------------------
Matrix2<
    Vector2< Fixed<0>, Imaginary >,
    Vector2< MinusImaginary, Fixed<0> > 
> sigma2;

//-----------------------------------------------------------------------------
Matrix2<
    Vector2< Fixed<0>, Float >,
    Vector2< Float, Fixed<0> > 
> m3(Unused(), 0.2f,
     0.8f, Unused());


// EDIT: Nicer initialization syntax in c++11

//-----------------------------------------------------------------------------
template <typename M00, typename M01, typename M10, typename M11>
Matrix2< Vector2<M00, M01>, Vector2<M10, M11> >
MakeMatrix(M00 const& m00, M01 const& m01, M10 const& m10, M11 const& m11)
{
    return Matrix2< Vector2<M00, M01>, Vector2<M10, M11> >(m00,m01,m10,m11);
}

auto m4 = MakeMatrix(Fixed<0>(),  Float(0.2f), 
                     Float(0.8f), Fixed<0>()  );