矩阵求幂的两种方法的比较

时间:2020-11-04 18:00:52

标签: performance numpy jax

我有两种方法对jnp = jax.numpy中的矩阵求幂。一种 简单的一个:

jnp.exp(-X/reg)

并采取一些其他措施:

def exp_reg(X, reg):
    K = jnp.empty_like(X)
    K = jnp.divide(X, -reg)
    return jnp.exp(K)

但是,当我测试它们时:

%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()

尽管表面上有一些额外的开销,第二种方法却跑赢了。我运行的%timeit的尺寸为2000 x 2000:

7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

为什么会这样?

1 个答案:

答案 0 :(得分:1)

这里的区别是操作顺序。

X中,您求反reg的每个条目,然后将结果的每个条目除以X。那是遍历数组exp_reg的两遍。

reg中,您要求反X(可能是一个标量值?),然后将X除以结果。那是X上的一遍。

如果X大,由于jit上的多次通过,我希望第一种方法比第二种方法稍慢。

幸运的是,由于您使用的是JAX,因此可以from jax import jit import jax.numpy as jnp import numpy as np def exp_reg1(X, reg): return jnp.exp(-X/reg) def exp_reg2(X, reg): K = jnp.divide(X, -reg) return jnp.exp(K) X = jnp.array(np.random.rand(1000, 1000)) reg = 2.0 %timeit exp_reg1(X, reg) # 100 loops, best of 3: 3.17 ms per loop %timeit exp_reg2(X, reg) # 100 loops, best of 3: 2.2 ms per loop # Trigger compilation jit(exp_reg1)(X, reg) jit(exp_reg2)(X, reg) %timeit jit(exp_reg1)(X, reg) # 1000 loops, best of 3: 1.92 ms per loop %timeit jit(exp_reg2)(X, reg) # 100 loops, best of 3: 1.84 ms per loop 编译代码,在这种情况下,XLA通常可以在类似的这些操作顺序上进行优化。确实,对于您的两个功能,编译可以消除差异:

K

(旁注:在将运算结果分配给相同名称的变量之前,没有理由预先分配空数组document.getElementById("form1").addEventListener("submit", function(e) { e.preventDefault(); // remove if you DO want the form to submit document.title = document.getElementById("customerName").value + " " + "Caller"; })