我遇到了无法对数组进行排序的问题。我收到此错误The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
,为什么会这样呢?我不明白是因为打领带吗?我已经在这个问题上待了一段时间了,无法解决。
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn import datasets, metrics, svm
from sklearn.model_selection import train_test_split
from collections import Counter
from math import sqrt
#import number data
digits = datasets.load_digits()
images_and_labels = list(zip(digits.images, digits.target))
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
data_train, data_test, label_train, label_test = train_test_split(data, digits.target, test_size=0.2)
def euclidean_distance(first, second):
distance = 0.0
for i in range(64):
distance += (first[i] - second[i])**2
return np.sqrt(distance)
def get_neighbors(train_set, test_set, num_neighbors):
distances = list()
for test_set in train_set:
dist = euclidean_distance(test_set, train_set)
distances.append((train_set, dist))
np.sort(distances)
neighbors = list()
for i in range(num_neighbors):
neighbors.append(distances[i][0])
return neighbors
results = get_neighbors(data_train, data_test, 100 )
答案 0 :(得分:0)
在这一行:
np.sort(distances)
distances是一个元组列表-每个元组包含一对numpy数组。例如:
>>> distances[0]
(array([[ 0., 0., 6., ..., 0., 0., 0.],
[ 0., 3., 13., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 16., 7., 0.],
...,
[ 0., 0., 1., ..., 16., 5., 0.],
[ 0., 1., 13., ..., 1., 0., 0.],
[ 0., 0., 1., ..., 14., 6., 0.]]),
array([54.40588203, 51.7107339 , 58.72818744, 83.80930736, 77.37570678,
58.25804665, 54.18486874, 54.47935389, 54.40588203, 52.13444159,
73.54590403, 87.35559513, 79.01898506, 66.55824517, 54.61684722,
54.56189146, 54.40588203, 50.55689864, 78.65748534, 74.60562981,
74.37741593, 72.70488292, 55.08175742, 54.41507144, 54.40588203,
49.43682838, 76.51143705, 70.9577339 , 75.94076639, 65.90902821,
56.90342696, 54.40588203, 54.40588203, 55.06359959, 73.06161783,
73.71566998, 82.50454533, 70. , 54. , 54.40588203,
54.40588203, 53.10367219, 75.16648189, 78.7273777 , 74.06753675,
67.00746227, 55.6596802 , 54.40588203, 54.40588203, 52.5832673 ,
68.82586723, 80.74651695, 74.24957912, 72.9725976 , 56.59505279,
53.8702144 , 54.40588203, 52.06726419, 61.10646447, 83.5463943 ,
84.92938243, 61.83041323, 53.38539126, 54.2678542 ]))
错误是因为np.sort
不知道如何处理。