子表达式在表达式模板中的嵌套

时间:2019-06-16 14:15:56

标签: c++ c++11 crtp perfect-forwarding expression-templates

我们正在编写一个表达式模板库,以处理带有稀疏梯度矢量(一阶自动微分)的值的运算。我试图弄清楚如何使引用或值嵌套子表达式成为可能,具体取决于表达式是否是临时的。

我们有一个Scalar类,它包含一个值和一个稀疏的梯度向量。我们使用表达式模板(例如Eigen)来防止构造和分配太多临时Scalar对象。因此,我们有了Scalar(CRTP)继承的类ScalarBase<Scalar>

类型ScalarBase< Left >ScalarBase< Right >的对象之间的二进制操作(例如+,*)返回一个ScalarBinaryOp<Left, Right,BinaryOp>对象,该对象继承自ScalarBase< ScalarBinaryOp<Left, Right,BinaryOp> >

template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+(
    const ScalarBase< Left >& left, const ScalarBase< Right >& right )
{
    return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( static_cast< const Left& >( left ),
        static_cast< const Right& >( right ), BinaryAdditionOp{} );
}

ScalarBinaryOp必须包含一个值或对LeftRight类型的操作数对象的引用。持有人的类型由RefTypeSelector< Expression >::Type的模板专业化定义。

当前,这始终是const引用。目前,它适用于我们的测试用例,但是对临时子表达式的引用似乎并不正确或不安全。

显然,我们也不希望复制包含稀疏梯度矢量的Scalar对象。如果xyScalar,则表达式x+y应包含对xy的常量引用。但是,如果f是从ScalarScalar的函数,则x+f(y)应该保存对x的常量引用和f(y)的值。 / p>

因此,我想传递有关子表达式是否为临时的信息。我可以将其添加到表达式类型参数中:

ScalarBinaryOp< typename Left, typename Right, typename BinaryOp , bool LeftIsTemporary, bool RightIsTemporary >

并转到RefTypeSelector

RefTypeSelector< Expression, ExpressionIsTemporary >::Type

但是接下来我需要为每个二进制运算符定义4种方法:

ScalarBinaryOp< Left, Right, BinaryAdditionOp, false, false > operator+(
    const ScalarBase< Left >& left, const ScalarBase< Right >& right );
ScalarBinaryOp< Left, Right, BinaryAdditionOp, false, true > operator+(
    const ScalarBase< Left >& left, ScalarBase< Right >&& right );
ScalarBinaryOp< Left, Right, BinaryAdditionOp, true, false > operator+(
    ScalarBase< Left >&& left, const ScalarBase< Right >& right );
ScalarBinaryOp< Left, Right, BinaryAdditionOp, true, true > operator+(
    ScalarBase< Left >&& left, ScalarBase< Right >&& right )

我希望能够通过完美的转发来实现这一目标。但是,我不知道如何在这里实现这一目标。首先,我不能使用简单的“通用引用”,因为它们几乎匹配所有内容。我猜可能将通用引用和SFINAE组合在一起以仅允许某些参数类型是可能的,但我不确定这是要走的路。另外,我想知道是否可以在类型为ScalarBinaryOp的Left和Right类型中编码有关Left和Right最初是左值还是右值引用的信息,而不是使用2个附加的booleans参数以及如何检索该信息。信息。

我必须支持主要符合c ++ 11的gcc 4.8.5。

1 个答案:

答案 0 :(得分:1)

您可以将左值/右值信息编码为LeftRight类型。例如:

ScalarBinaryOp<Left&&, Right&&> operator+(
    ScalarBase<Left>&& left, ScalarBase<Right>&& right)
{
    return ...;
}

ScalarBinaryOp是这样的:

template<class L, class R>
struct ScalarBinaryOp
{
    using Left = std::remove_reference_t<L>;
    using Right = std::remove_reference_t<R>;

    using My_left = std::conditional_t<
        std::is_rvalue_reference_v<L>, Left, const Left&>;
    using My_right = std::conditional_t<
        std::is_rvalue_reference_v<R>, Left, const Right&>;

    ...

    My_left left_;
    My_right right_;
};

或者,您可以明确显示并按值存储所有内容,但Scalar除外。为了能够按值存储Scalar,请使用包装器类:

x + Value_wrapper(f(y))

包装很简单:

struct Value_wrapper : Base<Value_wrapper> {
    Value_wrapper(Scalar&& scalar) : scalar_(std::move(scalar)) {}

    operator Scalar() const {
        return std::move(scalar_);
    }

    Scalar&& scalar_;
};

RefTypeSelector具有Value_wrapper的专门化:

template<> struct RefTypeSelector<Value_wrapper> {
    using Type = Scalar;
};

二进制运算符定义保持不变:

template<class Left, class Right>
ScalarBinaryOp<Left, Right> operator+(const Base<Left>& left, const Base<Right>& right) {
    return {static_cast<const Left&>(left), static_cast<const Right&>(right)};
}

完整示例:https://godbolt.org/z/sJ3NfG

(我在上面使用了一些C ++ 17功能只是为了简化表示法。)