Numpy维度:带标量/矩阵的乘法

时间:2015-04-17 16:37:52

标签: python numpy

我有一个变量X,它可以是标量或数组。在任何一种情况下,我都想用{/ 1>扩展X

y = np.array([0.5, 1.5])

如果X是标量,那将是

(X*y).sum()

如果X是一个矩阵,比如2维,我想做

(X[..., np.newaxis]*y[np.newaxis, np.newaxis, ...]).sum()

我正在使用

创建y
try:
    ndim = X.ndim
except AttributeError:
    ndim = 0
y = np.array([0.5, 1.5], ndmin=ndim+1)

允许我使用X[..., np.newaxis]*y进行乘法:y现在与我计算中X的形状无关。但是,如果X[..., np.newaxis]是矩阵,我仍然需要X,如果是标量,则只需要X

如何在代码开头操作X以便我可以

(X*y).sum()

以及之后的类似操作,无论X是否为矩阵?

1 个答案:

答案 0 :(得分:2)

不需要np.newaxis中的y。没有它们,您可以获得相同的结果。对于X,我猜您可以执行以下操作:

if type(X) == np.ndarray:
    result = (X[..., np.newaxis] * y).sum()
else:
    result = (X * y).sum()

如果要将其概括为多个操作,可以在代码中的某处添加一行:

X = X[..., np.newaxis] if type(X) == np.ndarray else X

稍后只使用(X * y).sum(),因为它适用于X = numberX = ndarray

对于y,您不需要添加数字维度,numpy有smart broadcasting用于乘法。

一个完整的例子:

>>> x1 = 5
>>> x2 = np.random.rand(3,3)
>>> y  = np.array([0.5, 0.5])
>>> (x1 * y).sum() # works fine
>>> (x2[..., np.newaxis] * y).sum() # also works fine