假设我有一个函数,它将NumPy数组x
和NumPy数组a
作为输入。例如:
def f(x, a):
"""
Compute the sum of all elements in x and the dot-product between x and a.
Args:
x, a: NumPy arrays.
"""
assert len(x) == len(a)
return np.sum(x), np.dot(x, a)
x = np.ones(5)
a = np.ones(5)
f(x, a)
一切正常。
但是,假设现在我希望函数在x
是二维数组时也能工作,方法是将sum函数应用于每一列。
我可以重写:
def f(x, a):
"""
Compute the column-sum of x and the dot-product between x and a.
Args:
x, a: NumPy arrays.
"""
assert x.shape[0] == len(a)
return np.sum(x, axis=1), np.dot(a, x)
x = np.ones((10, 5))
a = np.ones(5)
f(x, a)
当x是二维数组时会起作用。
但是,由于x = np.ones(5)
和assert x.shape[1] ...
由于np.sum(x, axis=1)
的形状为x
,我们尝试将其应用于(5,)
时,重写函数将失败}。
我如何才能为这两种情况制作相同的代码?
我能想到的唯一方法就是做一个丑陋的检查和重塑:
def f(x, a):
"""
Compute the column-sum of x and the dot-product between x and a.
Args:
x, a: NumPy arrays.
"""
if len(x.shape == 1):
x = x.reshape((len(x), 1))
assert x.shape[0] == len(a)
return np.sum(x, axis=1), np.dot(a, x)
但结果代码既不清晰也不高效......有没有更好的方法来矢量化代码?