我有一个变量X
,它可以是标量或数组。在任何一种情况下,我都想用{/ 1>扩展X
y = np.array([0.5, 1.5])
如果X
是标量,那将是
(X*y).sum()
如果X
是一个矩阵,比如2维,我想做
(X[..., np.newaxis]*y[np.newaxis, np.newaxis, ...]).sum()
我正在使用
创建y
try:
ndim = X.ndim
except AttributeError:
ndim = 0
y = np.array([0.5, 1.5], ndmin=ndim+1)
允许我使用X[..., np.newaxis]*y
进行乘法:y
现在与我计算中X
的形状无关。但是,如果X[..., np.newaxis]
是矩阵,我仍然需要X
,如果是标量,则只需要X
。
如何在代码开头操作X
以便我可以
(X*y).sum()
以及之后的类似操作,无论X
是否为矩阵?
答案 0 :(得分:2)
不需要np.newaxis
中的y
。没有它们,您可以获得相同的结果。对于X
,我猜您可以执行以下操作:
if type(X) == np.ndarray:
result = (X[..., np.newaxis] * y).sum()
else:
result = (X * y).sum()
如果要将其概括为多个操作,可以在代码中的某处添加一行:
X = X[..., np.newaxis] if type(X) == np.ndarray else X
稍后只使用(X * y).sum()
,因为它适用于X = number
和X = ndarray
。
对于y
,您不需要添加数字维度,numpy有smart broadcasting用于乘法。
一个完整的例子:
>>> x1 = 5
>>> x2 = np.random.rand(3,3)
>>> y = np.array([0.5, 0.5])
>>> (x1 * y).sum() # works fine
>>> (x2[..., np.newaxis] * y).sum() # also works fine