I ran KNN code, but the 'NearestNeighbor' object has no attribute 'y_train'

Time: 2019-05-14 08:16:37

Tags: python

I am running the KNN code from the Stanford computer vision course, but I get the error: 'NearestNeighbor' object has no attribute 'y_train'

import numpy as np

class NearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, y):
        # Learn the training instances
        self.X_train = X
        self.y_train = y

    def predict(self, X_te):
        num = X_te.shape[0]
        y_pred = np.zeros(num, dtype = self.y_train.dtype)
        for i in range(num):
            distances = np.sum(np.abs(self.X_train - X_te[i, :]), axis=1)
            min_index = np.argmin(distances)
            y_pred[i] = self.y_train[min_index]
        return y_pred


def unpickle(file):
    import pickle
    with open(file,'rb') as fo:
        dict = pickle.load(fo,encoding='bytes')
    return dict

data_train = unpickle(r'I:\course\Computer_vision\data\cifar_pic\cifar_10_batches_py\data_batch_2')
data_test = unpickle(r'I:\course\Computer_vision\data\cifar_pic\cifar_10_batches_py\test_batch')
# train data
X_train = data_train[b'data']
y_train = data_train[b'labels']
# test data
X_test = data_test[b'data']
y_test = data_test[b'labels']
# call NN
NearestNeighbor().train(X_train,y_train)
y_pred = NearestNeighbor().predict(X_test)
print(y_pred)

  

AttributeError: 'NearestNeighbor' object has no attribute 'y_train'

1 answer:

Answer 0: (score: 0)

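X_train and y_train are instance attributes of the NearestNeighbor class, set from the arguments X and y. They are only created when train() is called, so they exist only on the instance on which train() was called.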
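In the code as posted, the likely cause is that `NearestNeighbor().train(...)` and `NearestNeighbor().predict(...)` each construct a new object, so `predict` runs on an instance that was never trained. The sketch below keeps one instance and reuses it. It is only an illustration under a few assumptions: the small arrays stand in for the CIFAR-10 batches, the labels are converted with `np.asarray` because a plain Python list has no `.dtype`, and the pixel data is cast to a signed integer type because subtracting `uint8` arrays wraps around instead of going negative.

```python
import numpy as np

class NearestNeighbor(object):
    def train(self, X, y):
        # Store the training set on THIS instance.
        # Cast to a signed type so the L1 difference cannot wrap around (uint8 would).
        self.X_train = np.asarray(X).astype(np.int32)
        # Labels may arrive as a Python list; an array is needed for .dtype below.
        self.y_train = np.asarray(y)

    def predict(self, X_te):
        X_te = np.asarray(X_te).astype(np.int32)
        num = X_te.shape[0]
        y_pred = np.zeros(num, dtype=self.y_train.dtype)
        for i in range(num):
            # L1 distance from test row i to every training row
            distances = np.sum(np.abs(self.X_train - X_te[i, :]), axis=1)
            y_pred[i] = self.y_train[np.argmin(distances)]
        return y_pred

# Toy data standing in for the CIFAR-10 batches (assumption for illustration)
X_train = np.array([[0, 0], [10, 10], [20, 20]], dtype=np.uint8)
y_train = [0, 1, 2]
X_test = np.array([[1, 1], [19, 18]], dtype=np.uint8)

nn = NearestNeighbor()        # create the object once
nn.train(X_train, y_train)    # train that object
y_pred = nn.predict(X_test)   # predict with the same object
print(y_pred)                 # [0 2]
```

With the real data, the same pattern would apply: assign `NearestNeighbor()` to a variable, then call `train` and `predict` on that variable.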