混合三元运算和特征数组

时间:2017-06-10 23:49:18

标签: c++ eigen ternary-operator

我在Eigen docs(http://eigen.tuxfamily.org/index.php?title=Pit_Falls#Ternary_operator)中读到Eigen不能很好地处理三元运算;这当然是我的经历。

我要做的是根据几个布尔标志构建一个数组,下面是我的代码段中的use_XXX标志。我知道至少有一个标志在之前的检查中是真的,但是我无法编译这个块。以下是我尝试过的其他选项:

  1. 使用比特掩码之类的东西为umat构造2 ^ 4 = 16个逻辑选项 - 代码最终冗长且难以维护;呸...

  2. umat初始化为零,然后循环遍历条件,进行原地减法 - 当我手动注释掉术语时,这比单一总和慢很多

  3. 另一个想法是尝试将表达式乘以旗帜,希望Eigen会使用其模板魔法来弄清楚要做什么,但这也不起作用,因为在我的情况下我不会&# 39;如果我不使用它,则初始化数组(此循环中非常高性能的代码)

    umat = (
        (use_gauss_delta ? -coeffs.eta*delta_minus_epsilon.square() : 0)
        +
        (use_delta_ld ? -coeffs.cd*delta_to_ld : 0)
        +
        (use_gauss_tau ? -coeffs.beta*tau_minus_gamma.square() : 0)
        +
        (use_tau_lt ? -coeffs.ct*tau_to_lt : 0)
        )
    );
    

    修改

    我也尝试了select功能,但这很慢。每个mask_XXX都是Eigen::ArrayXi,其他所有都是Eigen::ArrayXd

    umat = (
        mask_gauss_delta.select(-coeffs.eta*delta_minus_epsilon.square(),0)
        +
        mask_delta_ld.select(-coeffs.cd*delta_to_ld,0)
        +
        mask_gauss_tau.select(-coeffs.beta*tau_minus_gamma.square(),0)
        +
        mask_tau_lt.select(-coeffs.ct*tau_to_lt,0)
    );
    

1 个答案:

答案 0 :(得分:1)

您可以通过将ArrayXd添加到一个条件,将类型(如您在问题中包含的链接中所述)强制为.eval()(或您正在使用的任何其他对象)。见下面的例子:

#include <Eigen/Core>
#include <iostream>

using Eigen::ArrayXd;
int main(int argc, char** argv)
{
    ArrayXd aa, res;
    int size = 6;
    aa.setLinSpaced(size, 0, 5);
    double d = 345.5;

    res = (true ? (d * aa.square()).eval() : ArrayXd::Zero(size));
    std::cout << res << std::endl;
    res = (false ? (d * aa.square()).eval() : ArrayXd::Zero(size));
    std::cout << res << std::endl;


    return 0;
}

d * aa.square()CwiseBinaryOp,其中ArrayXd::Zero(size)CwiseNullaryOp,其中任何一个都不能投射到另一个。将.eval()添加到一个使其成为ArrayXd(并将创建一个临时对象,您似乎不需要它)并使三元操作正常工作。然而,

whatever = 
(true  ? (d * aa.square()).eval() : ArrayXd::Zero(size)) + 
(false ? (d * aa.square()).eval() : ArrayXd::Zero(size));

仍然会导致ArrayXd::Zero(size)被评估为临时性能降低。可能性能最佳的选项是

if(use_gauss_delta) umat += -coeffs.eta*delta_minus_epsilon.square();
if(use_delta_ld)    umat += -coeffs.cd*delta_to_ld;
if(use_gauss_tau)   umat += -coeffs.beta*tau_minus_gamma.square();
if(use_tau_lt)      umat += -coeffs.ct*tau_to_lt;

主要的缺点是评估会发生多达四次,但如果没有构建你提到的2 ^ 4选项,我想不出办法避免这种情况。