具有多个元素的数组的真值不明确吗?

时间:2019-11-20 09:08:26

标签: python numpy machine-learning

我遇到了无法对数组进行排序的问题。我收到此错误The truth value of an array with more than one element is ambiguous. Use a.any() or a.all(),为什么会这样呢?我不明白是因为打领带吗?我已经在这个问题上待了一段时间了,无法解决。

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn import datasets, metrics, svm
from sklearn.model_selection import train_test_split
from collections import Counter
from math import sqrt

#import number data
digits = datasets.load_digits()
images_and_labels = list(zip(digits.images, digits.target))
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
data_train, data_test, label_train, label_test = train_test_split(data, digits.target, test_size=0.2)

def euclidean_distance(first, second):
    distance = 0.0
    for i in range(64):
        distance += (first[i] - second[i])**2
    return np.sqrt(distance)

def get_neighbors(train_set, test_set, num_neighbors):
    distances = list()
    for test_set in train_set:
        dist = euclidean_distance(test_set, train_set)
        distances.append((train_set, dist))
    np.sort(distances)
    neighbors = list()
    for i in range(num_neighbors):
        neighbors.append(distances[i][0])
    return neighbors

results = get_neighbors(data_train, data_test, 100 )

1 个答案:

答案 0 :(得分:0)

在这一行:

np.sort(distances)

distances是一个元组列表-每个元组包含一对numpy数组。例如:

>>> distances[0]
(array([[ 0.,  0.,  6., ...,  0.,  0.,  0.],
        [ 0.,  3., 13., ...,  0.,  0.,  0.],
        [ 0.,  0.,  0., ..., 16.,  7.,  0.],
        ...,
        [ 0.,  0.,  1., ..., 16.,  5.,  0.],
        [ 0.,  1., 13., ...,  1.,  0.,  0.],
        [ 0.,  0.,  1., ..., 14.,  6.,  0.]]),
 array([54.40588203, 51.7107339 , 58.72818744, 83.80930736, 77.37570678,
        58.25804665, 54.18486874, 54.47935389, 54.40588203, 52.13444159,
        73.54590403, 87.35559513, 79.01898506, 66.55824517, 54.61684722,
        54.56189146, 54.40588203, 50.55689864, 78.65748534, 74.60562981,
        74.37741593, 72.70488292, 55.08175742, 54.41507144, 54.40588203,
        49.43682838, 76.51143705, 70.9577339 , 75.94076639, 65.90902821,
        56.90342696, 54.40588203, 54.40588203, 55.06359959, 73.06161783,
        73.71566998, 82.50454533, 70.        , 54.        , 54.40588203,
        54.40588203, 53.10367219, 75.16648189, 78.7273777 , 74.06753675,
        67.00746227, 55.6596802 , 54.40588203, 54.40588203, 52.5832673 ,
        68.82586723, 80.74651695, 74.24957912, 72.9725976 , 56.59505279,
        53.8702144 , 54.40588203, 52.06726419, 61.10646447, 83.5463943 ,
        84.92938243, 61.83041323, 53.38539126, 54.2678542 ]))

错误是因为np.sort不知道如何处理。