np.newaxis与Numba nopython

时间:2016-08-04 07:21:31

标签: numpy numba numpy-broadcasting

有没有办法将np.newaxis与Numba nopython一起使用?为了应用广播功能而不在python上回退?

例如

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - b[:, np.newaxis] * c[np.newaxis, :]
    return d

由于

3 个答案:

答案 0 :(得分:2)

您可以使用重塑来完成此操作,看起来目前不支持std::string line; getline(cin, line); 索引。请注意,这可能不会比使用python快得多,因为它已经被矢量化了。

[:, None]

答案 1 :(得分:1)

这可以使用最新版本的Numba(0.27)和numpy stride_tricks来完成。你需要小心这一点,它有点难看。阅读as_strided的{​​{3}},以确保您了解正在进行的操作,因为这不是安全的"因为它没有检查形状或步幅。

import numpy as np
import numba as nb

a = np.random.randn(20, 10)
b = np.random.randn(20) 
c = np.random.randn(10)

def toto(a, b, c):

    d = a - b[:, np.newaxis] * c[np.newaxis, :]
    return d

@nb.jit(nopython=True)
def toto2(a, b, c):
    _b = np.lib.stride_tricks.as_strided(b, shape=(b.shape[0], 1), strides=(b.strides[0], 0))
    _c = np.lib.stride_tricks.as_strided(c, shape=(1, c.shape[0]), strides=(0, c.strides[0]))
    d = a - _b * _c

    return d

x = toto(a,b,c)
y = toto2(a,b,c)
print np.allclose(x, y) # True

答案 2 :(得分:0)

在我的盒子中(numba:0.35,numpy:1.14.0)expand_dims正常工作:

将numpy导入为np 从numba import jit

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - np.expand_dims(b, -1) * np.expand_dims(c, 0)
    return d

我们当然可以使用广播省略第二个expand_dims