如何确保numpy数组为2D行向量或列向量?

时间:2019-01-28 21:13:23

标签: python numpy

是否有一个numpy函数来确保1D或2D数组是列向量还是行向量?

例如,我具有以下载体/列表之一。将任何输入转换为列向量的最简单方法是什么?

x1 = np.array(range(5))
x2 = x1[np.newaxis, :]
x3 = x1[:, np.newaxis]

def ensureCol1D(x):
    # The input is either a 0D list or 1D.
    assert(len(x.shape)==1 or (len(x.shape)==2 and 1 in x.shape))
    x = np.atleast_2d(x)
    n = x.size
    print(x.shape, n)
    return x if x.shape[0] == n else x.T

assert(ensureCol1D(x1).shape == (x1.size, 1))
assert(ensureCol1D(x2).shape == (x2.size, 1))
assert(ensureCol1D(x3).shape == (x3.size, 1))

除了编写自己的函数ensureCol1D之外,numpy中是否已有类似的东西可以确保向量成为列?

1 个答案:

答案 0 :(得分:2)

您的问题本质上是如何将数组转换为“列”,列是2D数组,行长为1。这可以通过ndarray.reshape(-1, 1)完成。

这意味着您reshape的数组的行长为1,让numpy推断行数/列长。

x1 = np.array(range(5))
print(x1.reshape(-1, 1))

输出:

array([[0],
   [1],
   [2],
   [3],
   [4]])

重塑x2x3时会得到相同的输出。此外,这也适用于n维数组:

x = np.random.rand(1, 2, 3)
print(x.reshape(-1, 1).shape)

输出:

(6, 1)

最后,这里唯一缺少的是您做出一些断言,以确保不会转换无法正确转换的数组。您要进行的主要检查是,形状中非整数的数量小于或等于1。这可以通过以下方式完成:

assert sum(i != 1 for i in x1.shape) <= 1

此检查与.reshape一起让您将逻辑应用于所有numpy数组。