针对2d数组的每一行对另一个2d数组执行数学运算

时间:2019-07-31 01:01:47

标签: python numpy

我只能使用numpy导入。 我需要计算最接近的距离是测试集到训练集的距离。即在测试中找到最接近的距离(在训练数组中找到所有列表之间的距离),然后返回测试名称和训练名称。使用以下公式:

dist(x,y)=√((a-a2 )^2+(b-b2 )^2+(c-c2 )^2+(d-d2)^2 )

link到使用的数据并期望第一行。

这是我具有的代码,对于Train测试集中的第一行而言,该代码可以正常运行。我需要火车数组的每一行都经过变量q中的相同操作。 以下是我的输入

Training
a   b   c   d   name training
5   3   1.6 0.2 G
5   3.4 1.6 0.4 G
5.5 2.4 3.7 1   R
5.8 2.7 3.9 1.2 R
7.2 3.2 6   1.8 Y
6.2 2.8 4.8 1.8 Y

testing
a2  b2  c2  d2  name true
5   3.6 1.4 0.2 E
5.4 3.9 1.7 0.4 G
6.9 3.1 4.9 1.5 R
5.5 2.3 4   1.3 R
6.4 2.7 5.3 1.9 Y
6.8 3   5.5 2.1 Y
train = np.asarray(train)
test = np.asarray(test)
print('Train shape',train.shape)
print('test shape',test.shape)

train_1 = train[:,0:(train.shape[1])-1].astype(float)
test_1 = test[:,0:(test.shape[1])-1].astype(float)
print('Train '+'\n',train_1)
print('test '+'\`enter code here`n',test_1)
q=min((np.sqrt(np.sum((train_1[0,:]-test_1)**2,axis=1,keepdims=True))))

与整个测试阵列相比,我希望从训练行获得最接近的距离。使用此公式,使用公式的第一列火车将产生以下结果。然后我将返回G,E,因为这是最接近的2行。

1 个答案:

答案 0 :(得分:0)

您可以使用numpy.linalg.norm。这是一个例子:

>>> import numpy as np
>>> arr = np.array([1, 2, 3, 4])
>>> np.linalg.norm(arr)
5.477225575051661

5.477225575051661sqrt(1^2 + 2^2 + 3^2 + 4^2)

的结果
import numpy as np

train = np.array([[5, 3, 1.6, 0.2],
                  [5, 3.4, 1.6, 0.4],
                  [5.5, 2.4, 3.7, 1],
                  [5.8, 2.7, 3.9, 1.2],
                  [7.2, 3.2, 6, 1.8],
                  [6.2, 2.8, 4.8, 1.8]])

test = np.array([[5, 3.6, 1.4, 0.2],
                 [5.4, 3.9, 1.7, 0.4],
                 [6.9, 3.1, 4.9, 1.5],
                 [5.5, 2.3, 4, 1.3],
                 [6.4, 2.7, 5.3, 1.9],
                 [6.8, 3, 5.5, 2.1]])

# first get subtraction of each row of train to test
subtraction = train[:, None, :] - test[None, :, :]
# get distance from each train_row to test
s = np.linalg.norm(subtraction, axis=2, keepdims=True)
print(np.min(s, axis=1))
# get minimum
q = np.argmin(s, axis=1)
print("minimum indices:")
print(q)

输出:

[[0.63245553]
 [0.34641016]
 [0.43588989]
 [0.51961524]
 [0.73484692]
 [0.55677644]]
minimum indices:
[[0]
 [0]
 [3]
 [3]
 [5]
 [4]]