我正在尝试使用public repository中的代码来训练带有一组图像的kNN模型。它最初用于处理群集之间所有图像的相似性。但是我想使用一个新图像(模型中未包含),并从原始群集中获得最相似的图像。
这是训练原始kNN的代码
for f in os.listdir(path):
# Process filename
filename = os.path.splitext(f) # filename in directory
filename_full = os.path.join(path,f) # full path filename
head, ext = filename[0], filename[1]
if ext.lower() not in [".jpg", ".jpeg"]:
continue
# Read image file
img = image.load_img(filename_full, target_size=(224, 224)) #
load
imgs.append(np.array(img)) # image
filename_heads.append(head) # filename head
# Pre-process for model input
img = process_image(img)
features = model.predict(img).flatten() # features
eX.append(features) # append feature extractor
filename_heads.append(head)
X = np.array(eX) # feature vectors
imgs = np.array(imgs) # images
n_neighbours = 5 + 1
knn = kNN() # kNN model
knn.compile(n_neighbors=n_neighbours, algorithm="brute", metric="cosine")
knn.fit(X)
这是我的代码,用于查询新图像并在原始群集中查找相似图像
#previously I read the image from an url and put it in img variable
img = image.load_img('db/temp.jpg', target_size=(224, 224)) # load
img = image.img_to_array(img) # convert to array
img = np.expand_dims(img, axis=0)
img = preprocess_input(img)
img_features = model.predict(img).flatten() # features
distances, indices = knn.predict(img_features)
问题是我收到“ IndexError:元组索引超出范围 运行knn.predict(new_img_features)时出现错误。我已经查看了img_features的形状和类型,并且它们都是相同的,所以我真的不知道为什么会出现此错误。也许是因为这里使用的kNN不是分类器,但我不知道如何对其进行调整才能起作用。
Full code link,以防万一您想签出来。
答案 0 :(得分:1)
问题是我必须以这种方式传递矩阵:
distances, indices = knn.predict(np.array([img_features]))