Scala DSLs and typed operators: idiomatic implementation?

Time: 2016-04-04 17:34:10

Tags: scala operators dsl

Consider:

val a:ExpressionT[Int] = Const(1)
val b:ExpressionT[Float] = Const(2.0f)

// + must promote types:
//    (i+i) -> i
//    (i+f) -> f
//    (f+i) -> f
//    (f+f) -> f
val c:ExpressionT[Float] = a+b

where:

// Untyped base trait
trait Expression

// A typed Expression
trait ExpressionT[T] extends Expression {
   def evaluate(): T
}

class Const[T<:Any](value:T) extends ExpressionT[T]  {
   def evaluate(): T = value
}

object Const {
  def apply(value:Int) = new Const[Int](value) 
  def apply(value:Float) = new Const[Float](value)
}

What's the most elegant way to implement the + operator so that it works for all numeric types? Is there a way to avoid enumerating every combination of types?

In C++ we can simply do this (from memory):

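#include <utility>  // std::declval

template <typename T>
class ExpressionT {
public:
    typedef T EvalT;
    virtual ~ExpressionT() {}
    virtual T evaluate() const = 0;
};

// HighestType<X,Y>::EvalT, e.g.
// (int,float) --> float
// (float,int) --> float
// (int,int) --> int
// One possible definition: the type of x + y itself (fails to compile if + is undefined).
template <typename X, typename Y>
struct HighestType {
    typedef decltype(std::declval<X>() + std::declval<Y>()) EvalT;
};

template <typename X, typename Y>
class Add : public ExpressionT<typename HighestType<X, Y>::EvalT> {
public:
    typedef typename HighestType<X, Y>::EvalT ResultT;
    // Stores references: the operands must outlive this node.
    Add(ExpressionT<X> const& lhs, ExpressionT<Y> const& rhs) : l(lhs), r(rhs) {}
    ResultT evaluate() const override { return l.evaluate() + r.evaluate(); }
private:
    ExpressionT<X> const& l;
    ExpressionT<Y> const& r;
};

template <typename X, typename Y>
Add<X, Y> operator+(ExpressionT<X> const& lhs, ExpressionT<Y> const& rhs) {
    return Add<X, Y>(lhs, rhs);
}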

This requires no per-type cases, fails to compile if +(lhs, rhs) is not defined for the operand types, and extends straightforwardly to new types, such as +(int, Matrix) or +(Matrix, float), without any additional work.
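For comparison, a rough Scala counterpart of the C++ HighestType<X,Y>::EvalT is a type class with an abstract type member (often called the Aux pattern). This is an illustrative sketch under stated assumptions, not code from the question or the answer: HighestTypeSketch, HighestType, Aux and the stand-in Matrix class are hypothetical names. Unlike the C++ version, which gets promotion for free from the built-in +, it still lists the four Int/Float pairs explicitly. What it does share is that an unsupported pair fails to compile because no implicit instance is found, and that a new type plugs in by adding one instance in its own companion object.

```scala
object HighestTypeSketch {
  // Hypothetical counterpart of the C++ HighestType<X, Y>::EvalT: a type
  // class whose abstract type member EvalT is the result type of X + Y.
  trait HighestType[X, Y] {
    type EvalT
    def add(x: X, y: Y): EvalT
  }

  object HighestType {
    // Aux exposes EvalT as a type parameter so instances can state it.
    type Aux[X, Y, R] = HighestType[X, Y] { type EvalT = R }

    private def instance[X, Y, R](f: (X, Y) => R): Aux[X, Y, R] =
      new HighestType[X, Y] {
        type EvalT = R
        def add(x: X, y: Y): R = f(x, y)
      }

    // The four cases from the question's comment table, listed explicitly.
    implicit val intInt: Aux[Int, Int, Int] = instance((x: Int, y: Int) => x + y)
    implicit val intFloat: Aux[Int, Float, Float] = instance((x: Int, y: Float) => x + y)
    implicit val floatInt: Aux[Float, Int, Float] = instance((x: Float, y: Int) => x + y)
    implicit val floatFloat: Aux[Float, Float, Float] = instance((x: Float, y: Float) => x + y)
  }

  trait ExpressionT[T] {
    def evaluate(): T

    // The result type h.EvalT depends on the instance found by implicit search;
    // if no HighestType[T, U] exists, the call does not compile.
    def +[U](that: ExpressionT[U])(implicit h: HighestType[T, U]): ExpressionT[h.EvalT] = {
      val lhs = this
      new ExpressionT[h.EvalT] {
        // Evaluation is deferred until evaluate() is called on the sum node.
        def evaluate(): h.EvalT = h.add(lhs.evaluate(), that.evaluate())
      }
    }
  }

  final class Const[T](value: T) extends ExpressionT[T] {
    def evaluate(): T = value
  }

  object Const {
    def apply[T](value: T): Const[T] = new Const(value)
  }

  // Hypothetical user-defined type standing in for the question's Matrix.
  final case class Matrix(values: Vector[Float])

  object Matrix {
    // Supporting Int + Matrix only needs this one instance, found through
    // Matrix's companion object; nothing above has to change.
    implicit val intPlusMatrix: HighestType.Aux[Int, Matrix, Matrix] =
      new HighestType[Int, Matrix] {
        type EvalT = Matrix
        def add(x: Int, m: Matrix): Matrix = Matrix(m.values.map(_ + x))
      }
  }
}
```

With these definitions, a + b from the question has the static type ExpressionT[Float], while a + a stays ExpressionT[Int].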

1 answer:

Answer 0 (score: 1)

You could do something like this:

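  // Must be defined inside an object, class or package object.
  implicit class ExpressionCalc[T](val it: ExpressionT[T])(implicit numeric: Numeric[T]) {
    // Both operands must have the same type T.
    def +(that: ExpressionT[T]): ExpressionT[T] =
      new Const(numeric.plus(it.evaluate(), that.evaluate()))
  }

 // Const has no toString, so print the evaluated value rather than the node.
 println((Const(5) + Const(6)).evaluate())    // 11
 println((Const(5.6f) + Const(6)).evaluate())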

Unfortunately, this does not fully cover what you want: both operands must have the same type T, so Const(5) + Const(6.7f) does not compile, whereas you want the result to be Float whenever at least one of the arguments is Float.

However, the Saddle library (https://github.com/saddle/saddle/) provides a BinOp type class that lets you do what you want:

  // Assumes Saddle's BinOp and Add type classes are imported and in scope.
  object Const {
    def apply(value:Int) = new Const[Int](value)
    def apply(value:Double) = new Const[Double](value)
  }

  implicit class ExpressionCalc[This](val it: ExpressionT[This]) {
    // The result type R is determined by the BinOp instance Saddle provides.
    def +[That, R](that: ExpressionT[That])(implicit op: BinOp[Add, This, That, R]): ExpressionT[R] = {
      new Const(op(it.evaluate(), that.evaluate()))
    }
  }

  println((Const(5) + Const(6)).evaluate())
  println((Const(5.6) + Const(6)).evaluate())
  println((Const(5) + Const(6.7)).evaluate())

Note that Saddle does not support Float, so I've changed it to Double. You could also look at how Saddle implements its BinOp and use the same approach in your own code.
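Following that suggestion, here is a minimal sketch of the BinOp idea that does not depend on Saddle. It is not Saddle's actual code: BinOpSketch, AddOp, Widen and Plus are hypothetical names. As with BinOp, the result type R is inferred from whichever implicit instance is found. To avoid listing every pair of operand types, the sketch declares one Widen instance per widening step and derives the mixed-type cases from it together with the standard library's Numeric. Pairs with no widening relation (Long and Float here, for example) simply do not compile. Unlike the answer's code above, + builds a Plus node and defers the arithmetic until evaluate() is called.

```scala
object BinOpSketch {
  // Hypothetical Saddle-style operator type class: adding an A and a B yields an R.
  trait AddOp[A, B, R] { def apply(a: A, b: B): R }

  // Hypothetical relation "A can be widened to B".
  trait Widen[A, B] { def apply(a: A): B }

  object Widen {
    private def instance[A, B](f: A => B): Widen[A, B] =
      new Widen[A, B] { def apply(a: A): B = f(a) }

    // One instance per widening step, not one per pair of operand types.
    implicit val intToLong: Widen[Int, Long] = instance((i: Int) => i.toLong)
    implicit val intToFloat: Widen[Int, Float] = instance((i: Int) => i.toFloat)
    implicit val intToDouble: Widen[Int, Double] = instance((i: Int) => i.toDouble)
    implicit val longToDouble: Widen[Long, Double] = instance((l: Long) => l.toDouble)
    implicit val floatToDouble: Widen[Float, Double] = instance((f: Float) => f.toDouble)
  }

  object AddOp {
    // Same operand types: delegate to the standard library's Numeric.
    implicit def same[T](implicit n: Numeric[T]): AddOp[T, T, T] =
      new AddOp[T, T, T] { def apply(a: T, b: T): T = n.plus(a, b) }

    // Left operand widens to the right operand's type, e.g. (Int, Float) -> Float.
    implicit def widenLeft[A, B](implicit w: Widen[A, B], n: Numeric[B]): AddOp[A, B, B] =
      new AddOp[A, B, B] { def apply(a: A, b: B): B = n.plus(w(a), b) }

    // Right operand widens to the left operand's type, e.g. (Float, Int) -> Float.
    implicit def widenRight[A, B](implicit w: Widen[B, A], n: Numeric[A]): AddOp[A, B, A] =
      new AddOp[A, B, A] { def apply(a: A, b: B): A = n.plus(a, w(b)) }
  }

  trait ExpressionT[T] {
    def evaluate(): T

    // R is left open here and fixed by the AddOp instance that implicit search finds.
    def +[U, R](that: ExpressionT[U])(implicit op: AddOp[T, U, R]): ExpressionT[R] =
      new Plus(this, that, op)
  }

  final class Const[T](value: T) extends ExpressionT[T] {
    def evaluate(): T = value
  }

  object Const {
    def apply[T](value: T): Const[T] = new Const(value)
  }

  // Expression node: the operands are only added when evaluate() is called.
  final class Plus[A, B, R](l: ExpressionT[A], r: ExpressionT[B], op: AddOp[A, B, R])
      extends ExpressionT[R] {
    def evaluate(): R = op(l.evaluate(), r.evaluate())
  }
}
```

With these definitions, the question's val c: ExpressionT[Float] = a + b compiles, b + a is also an ExpressionT[Float], and a + a stays an ExpressionT[Int]. Supporting a new type means adding an AddOp instance (or a Widen step) for it, without touching the existing instances.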