构建模板化ODE求解器

时间:2014-04-26 08:46:11

标签: c++ templates metaprogramming

我正在尝试使用以下函数声明实现模板化的ODE解算器:

template<class ODEFunction,class StopCondition=decltype(continue_always<ODEFunction>)> 
bool euler_fwd(ODEFunction& f,typename State<ODEFunction>::type& x_0
    ,double t_0,double dt,size_t N_iter
    ,StopCondition& cond=continue_always<ODEFunction>);

完整来源:

/*From SO answer*/

template<class F>
struct State;

template <class R, class... A> 
struct State<R (*)(A...)>
    {
    typedef R type;
    };

/*End from SO answer*/


/**Default stop condition. Always return 0 to continue operation.
*/
template<class ODEFunction>
bool continue_always(const typename State<ODEFunction>::type& x_0,double t_0)
    {return 0;}

/**Euler forward solver
*/
template<class ODEFunction,class StopCondition=decltype(continue_always<ODEFunction>)> 
bool euler_fwd(ODEFunction& f,typename State<ODEFunction>::type& x_0
    ,double t_0,double dt,size_t N_iter
    ,StopCondition& cond=continue_always<ODEFunction>)
    {
    size_t k=0;
    while(N_iter)
        {
        if(cond(x_0,t_0))
           {return 1;}
        x_0+=dt*f(x_0,k*dt);
        --N_iter;
        ++k;
        }
    return 0;
    }

试图用简单的函数调用euler_fwd

double f(double x,double t)
    {return x;}

省略continue_always谓词,GCC写

  

错误:无效使用不完整类型'struct State'      bool continue_always(const typename State :: type&amp; x_0,double t_0)

...

  

test.cpp:18:47:错误:没有匹配函数来调用'euler_fwd(double(&amp;)(double,double),double&amp;,double&amp;,double&amp;,size_t&amp;)'

编辑:

如果我尝试跳过使用cond的默认参数:

euler_fwd(testfunc,x_0,t_0,dt,N,continue_always<decltype(testfunc)>);

我仍然收到错误

  

test.cpp:18:97:注意:无法解决重载函数'continue_always'的地址

1 个答案:

答案 0 :(得分:0)

不是试图操纵函数类型(fcond的类型),为什么不将状态设为模板参数呢?

#include <cstdlib> // for size_t
#include <functional>

template<class R>
bool continue_always(const R& x_0,double t_0)
{return 0;}

template<class R> 
bool euler_fwd(std::function<R(const R &)> f,R& x_0
  ,double t_0,double dt,size_t N_iter
  ,std::function<bool(const R&, double)> cond=continue_always<R>)
{
  size_t k=0;
  while(N_iter)
  {
    if(cond(x_0,t_0))
      {return 1;}
    x_0+=dt*f(x_0,k*dt);
    --N_iter;
    ++k;
  }
  return 0;
}