使用原始群集图片之外的查询kNN

时间:2018-08-14 12:21:21

标签: python scikit-learn computer-vision

我正在尝试使用public repository中的代码来训练带有一组图像的kNN模型。它最初用于处理群集之间所有图像的相似性。但是我想使用一个新图像(模型中未包含),并从原始群集中获得最相似的图像。

这是训练原始kNN的代码

for f in os.listdir(path):

    # Process filename
    filename = os.path.splitext(f)  # filename in directory
    filename_full = os.path.join(path,f)  # full path filename
    head, ext = filename[0], filename[1]
    if ext.lower() not in [".jpg", ".jpeg"]:
        continue

    # Read image file
    img = image.load_img(filename_full, target_size=(224, 224))  # 
    load
    imgs.append(np.array(img))  # image
    filename_heads.append(head)  # filename head

    # Pre-process for model input
    img = process_image(img)
    features = model.predict(img).flatten()  # features
    eX.append(features)  # append feature extractor

filename_heads.append(head)

X = np.array(eX)  # feature vectors
imgs = np.array(imgs)  # images

n_neighbours = 5 + 1 
knn = kNN()  # kNN model
knn.compile(n_neighbors=n_neighbours, algorithm="brute", metric="cosine")
knn.fit(X)

这是我的代码,用于查询新图像并在原始群集中查找相似图像

#previously I read the image from an url and put it in img variable
img = image.load_img('db/temp.jpg', target_size=(224, 224))  # load 
img = image.img_to_array(img)  # convert to array
img = np.expand_dims(img, axis=0)
img = preprocess_input(img)
img_features = model.predict(img).flatten()  # features
distances, indices = knn.predict(img_features)

问题是我收到“ IndexError:元组索引超出范围 运行knn.predict(new_img_features)时出现错误。我已经查看了img_features的形状和类型,并且它们都是相同的,所以我真的不知道为什么会出现此错误。也许是因为这里使用的kNN不是分类器,但我不知道如何对其进行调整才能起作用。

Full code link,以防万一您想签出来。

1 个答案:

答案 0 :(得分:1)

问题是我必须以这种方式传递矩阵:

distances, indices = knn.predict(np.array([img_features]))