为阵列矢量化NumPy代码

时间:2018-01-08 19:37:08

标签: python arrays numpy vectorization

假设我有一个函数,它将NumPy数组x和NumPy数组a作为输入。例如:

def f(x, a): 
    """
    Compute the sum of all elements in x and the dot-product between x and a.

    Args:
        x, a: NumPy arrays.
    """
    assert len(x) == len(a) 

    return np.sum(x), np.dot(x, a)

x = np.ones(5)
a = np.ones(5)

f(x, a)

一切正常。 但是,假设现在我希望函数在x是二维数组时也能工作,方法是将sum函数应用于每一列。 我可以重写:

def f(x, a): 
    """
    Compute the column-sum of x and the dot-product between x and a.

    Args:
        x, a: NumPy arrays.
    """
    assert x.shape[0] == len(a) 

    return np.sum(x, axis=1), np.dot(a, x)

x = np.ones((10, 5))
a = np.ones(5)

f(x, a)

当x是二维数组时会起作用。 但是,由于x = np.ones(5)assert x.shape[1] ...由于np.sum(x, axis=1)的形状为x,我们尝试将其应用于(5,)时,重写函数将失败}。 我如何才能为这两种情况制作相同的代码? 我能想到的唯一方法就是做一个丑陋的检查和重塑:

def f(x, a): 
    """
    Compute the column-sum of x and the dot-product between x and a.

    Args:
        x, a: NumPy arrays.
    """
    if len(x.shape == 1):
        x = x.reshape((len(x), 1))

    assert x.shape[0] == len(a) 

    return np.sum(x, axis=1), np.dot(a, x)

但结果代码既不清晰也不高效......有没有更好的方法来矢量化代码?

0 个答案:

没有答案