如何使用锈宏简化数学公式?

时间:2019-04-05 20:43:20

标签: rust rust-macros

我必须承认我对宏有点迷路。 我想建立一个执行以下任务的宏, 我不确定该怎么做。我想执行标量产品 长度为N的两个数组(例如x和y)中的一个。 我要计算的结果的形式为:

z = sum_{i=0}^{N-1} x[i] * y[i].

xconst,其中元素是0, 1, or -1,在编译时是已知的, y的元素是在运行时确定的。因为 x的结构,许多计算都是无用的(项乘以0 可以从总和中删除,并且形式1 * y[i], -1 * y[i]的乘法可以分别转换为y[i], -y[i]

例如,如果x = [-1, 1, 0],上面的标量积将是

z=-1 * y[0] + 1 * y[1] + 0 * y[2]

为了加快计算速度,我可以手动展开循环并重写 整个过程没有x[i],我可以将上面的公式硬编码为

z = -y[0] + y[1]

但是此过程并不优雅,容易出错 而且当N变大时非常繁琐。

我很确定我可以使用宏来执行此操作,但是我不知道在哪里 开始(我读过的不同书籍对宏的了解并不深, 我被卡住了......

你们中的任何人是否有办法使用宏(如果可能)来解决这个问题?

提前感谢您的帮助!

编辑:正如许多答案所指出的那样,对于整数,编译器足够聪明,可以删除优化循环。我不仅在使用整数,而且还在使用浮点数(x数组是i32s,但通常yf64 s),所以编译器不够聪明(理应如此)优化循环。以下代码段给出了以下asm。

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    X.iter().zip(y.iter()).map(|(i, j)| (*i as f64) * j).sum()
}
playground::dot_x:
    xorpd   %xmm0, %xmm0
    movsd   (%rdi), %xmm1
    mulsd   %xmm0, %xmm1
    addsd   %xmm0, %xmm1
    addsd   8(%rdi), %xmm1
    subsd   16(%rdi), %xmm1
    movupd  24(%rdi), %xmm2
    xorpd   %xmm3, %xmm3
    mulpd   %xmm2, %xmm3
    addsd   %xmm3, %xmm1
    unpckhpd    %xmm3, %xmm3
    addsd   %xmm1, %xmm3
    addsd   40(%rdi), %xmm3
    mulsd   48(%rdi), %xmm0
    addsd   %xmm3, %xmm0
    subsd   56(%rdi), %xmm0
    retq

4 个答案:

答案 0 :(得分:8)

首先,一个(proc)宏根本无法在数组x内查看。它得到的只是您传递它的令牌,没有任何上下文。如果您想让它知道值(0,1,-1),则需要将这些值直接传递给您的宏:

let result = your_macro!(y, -1, 0, 1, -1);

但是您实际上并不需要宏。编译器做了很多优化,如其他答案所示。但是,正如您在编辑中已经提到的那样,它不会优化0.0 * x[i],因为这样做的结果并不总是0.0。 (例如,可以是-0.0NaN。)我们在这里可以做的只是使用matchif来帮助优化器,以确保对于0.0 * y情况不起作用:

const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];

fn foobar(y: [f64; 8]) -> f64 {
    let mut sum = 0.0;
    for (&x, &y) in X.iter().zip(&y) {
        if x != 0 {
            sum += x as f64 * y;
        }
    }
    sum
}

在发布模式下,展开循环,并内联X的值,导致大多数迭代由于不执行任何操作而被丢弃。生成的二进制文件(在x86_64上)剩下的唯一内容是:

foobar:
 xorpd   xmm0, xmm0
 subsd   xmm0, qword, ptr, [rdi, +, 8]
 addsd   xmm0, qword, ptr, [rdi, +, 48]
 ret
  

(如@ lu-zero所建议,也可以使用filter_map来完成。它看起来像这样:X.iter().zip(&y).filter_map(|(&x, &y)| match x { 0 => None, _ => Some(x as f64 * y) }).sum(),并给出完全相同的生成程序集。甚至没有{{ 1}},分别使用matchfiltermap。)

还不错!但是,此函数将计算.filter(|(&x, _)| x != 0).map(|(&x, &y)| x as f64 * y).sum(),因为0.0 - y[1] + y[6]sum开始,我们仅对其进行减法和加法。优化器再次不愿意优化0.0。我们可以从0.0开始而不是0.0来提供更多帮助:

None

结果是:

fn foobar(y: [f64; 8]) -> f64 {
    let mut sum = None;
    for (&x, &y) in X.iter().zip(&y) {
        if x != 0 {
            let p = x as f64 * y;
            sum = Some(sum.map_or(p, |s| s + p));
        }
    }
    sum.unwrap_or(0.0)
}

仅执行foobar: movsd xmm0, qword, ptr, [rdi, +, 48] subsd xmm0, qword, ptr, [rdi, +, 8] ret 。宾果游戏!

答案 1 :(得分:3)

您可以通过返回一个函数的宏来实现您的目标。

首先,编写不带宏的此函数。这需要固定数量的参数。

fn main() {
    println!("Hello, world!");
    let func = gen_sum([1,2,3]);
    println!("{}", func([4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}

fn gen_sum(xs: [i32; 3]) -> impl Fn([i32;3]) -> i32 {
    move |ys| ys[0]*xs[0] + ys[1]*xs[1] + ys[2]*xs[2]
}

现在,完全重写它,因为先前的设计不能很好地用作宏。我们不得不放弃固定大小的数组,例如macros appear unable to allocate fixed-sized arrays

Rust Playground

fn main() {
    let func = gen_sum!(1,2,3);
    println!("{}", func(vec![4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}

#[macro_export]
macro_rules! gen_sum {
    ( $( $x:expr ),* ) => {
        {
            let mut xs = Vec::new();
            $(
                xs.push($x);
            )*
            move |ys:Vec<i32>| {
                if xs.len() != ys.len() {
                    panic!("lengths don't match")
                }
                let mut total = 0;
                for i in 0 as usize .. xs.len() {
                    total += xs[i] * ys[i];
                }
                total
            } 
        }
    };
}

这是什么/应该做什么

在编译时,它会生成一个lambda。此lambda接受数字列表,并将其乘以在编译时生成的vec。我不认为这正是您追求的目标,因为它在编译时不会优化零值。您可以在编译时优化零,但必须在运行时通过检查零在x中的位置来确定要在y中乘以哪些元素,从而在运行时产生一定的成本。您甚至可以使用哈希集在固定时间内进行此查找过程。一般而言,它仍然可能不值得(我认为0并不那么普遍)。计算机胜于检测“要做的事”是“低效”然后跳过该事,而不是做“低效”的事。当它们执行的大部分操作“效率低下”时,这种抽象就会失效

跟进

那值得吗?它会缩短运行时间吗?我没有测量,但是与仅使用函数相比,理解和维护我编写的宏似乎不值得。编写一个宏来进行您所谈及的零优化可能会更令人讨厌。

答案 2 :(得分:3)

在许多情况下,编译器的优化阶段将为您解决这一问题。举个例子,这个函数定义

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [i32; 8]) -> i32 {
    X.iter().zip(y.iter()).map(|(i, j)| i * j).sum()
}

在x86_64上产生此程序集输出:

playground::dot_x:
    mov eax, dword ptr [rdi + 4]
    sub eax, dword ptr [rdi + 8]
    add eax, dword ptr [rdi + 20]
    sub eax, dword ptr [rdi + 28]
    ret

您将无法获得比该版本更优化的版本,因此仅以幼稚的方式编写代码是最佳的解决方案。尚不清楚编译器是否会为更长的向量展开循环,这可能会随编译器版本而改变。

对于浮点数,编译器通常无法执行上述所有优化,因为y中的数字不能保证是有限的,它们也可以是NaN,{{ 1}}或inf。因此,不能保证与-inf的乘法运算会再次导致0.0,因此编译器需要将乘法指令保留在代码中。通过使用0.0内在函数,您可以明确允许它假定所有数字都是有限的:

fmul_fast()

这将导致以下汇编代码:

#![feature(core_intrinsics)]
use std::intrinsics::fmul_fast;

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    X.iter().zip(y.iter()).map(|(i, j)| unsafe { fmul_fast(*i as f64, *j) }).sum()
}

这仍然在步骤之间多余地添加了零,但是我不希望这会导致现实CFD仿真产生任何可测量的开销,因为这样的仿真往往受内存带宽而不是CPU的限制。如果也要避免这些添加,则需要对添加使用playground::dot_x: # @playground::dot_x # %bb.0: xorpd xmm1, xmm1 movsd xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero addsd xmm0, xmm1 subsd xmm0, qword ptr [rdi + 16] addsd xmm0, xmm1 addsd xmm0, qword ptr [rdi + 40] addsd xmm0, xmm1 subsd xmm0, qword ptr [rdi + 56] ret ,以使编译器进一步优化:

fadd_fast()

这将导致以下汇编代码:

#![feature(core_intrinsics)]
use std::intrinsics::{fadd_fast, fmul_fast};

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    let mut result = 0.0;
    for (&i, &j) in X.iter().zip(y.iter()) {
        unsafe { result = fadd_fast(result, fmul_fast(i as f64, j)); }
    }
    result
}

与所有优化一样,您应该从最易读和可维护的代码版本开始。如果性能成为问题,则应该分析代码并找到瓶颈。下一步,尝试改善基本方法,例如通过使用具有更好渐近复杂度的算法。只有这样,您才应该转向微优化,就像您在问题中建议的那样。

答案 3 :(得分:3)

如果您可以使用明确的filter_map()来节省#[inline(always)],则足以使编译器执行您想要的操作。