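I have a dataset with only 2 classes, and I am trying to train on it with train_tripletloss.py, following https://github.com/davidsandberg/facenet/wiki/Triplet-loss-training
Please help me solve the problem...
The part of the code where the error occurs is shown below: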
import numpy as np

def sample_people(dataset, people_per_batch, images_per_person):
    nrof_images = people_per_batch * images_per_person

    # Sample classes from the dataset
    nrof_classes = len(dataset)
    class_indices = np.arange(nrof_classes)
    np.random.shuffle(class_indices)

    i = 0
    image_paths = []
    num_per_class = []
    sampled_class_indices = []
    # Sample images from these classes until we have enough
    while len(image_paths) < nrof_images:
        # i is never checked against nrof_classes, so this indexes past the
        # end of class_indices if the classes run out before nrof_images is reached
        class_index = class_indices[i]
        nrof_images_in_class = len(dataset[class_index])
        image_indices = np.arange(nrof_images_in_class)
        np.random.shuffle(image_indices)
        # Each class contributes at most images_per_person images
        nrof_images_from_class = min(nrof_images_in_class, images_per_person, nrof_images - len(image_paths))
        idx = image_indices[0:nrof_images_from_class]
        image_paths_for_class = [dataset[class_index].image_paths[j] for j in idx]
        sampled_class_indices += [class_index] * nrof_images_from_class
        image_paths += image_paths_for_class
        num_per_class.append(nrof_images_from_class)
        i += 1

    return image_paths, num_per_class
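The while loop only stops once it has collected people_per_batch * images_per_person paths, and each class adds at most min(class size, images_per_person) of them. So the loop runs out of classes exactly when the sum of those per-class contributions is smaller than the requested total. A minimal sketch of that check, assuming you know how many images each class has (the helper name and the sizes in the usage lines are hypothetical, for illustration only):

```python
def will_run_out_of_classes(class_sizes, people_per_batch, images_per_person):
    """Hypothetical helper: True if sample_people would index past the last class.

    Each class contributes at most min(size, images_per_person) images, so the
    loop needs more classes than exist when the total of those contributions is
    below people_per_batch * images_per_person.
    """
    needed = people_per_batch * images_per_person
    available = sum(min(n, images_per_person) for n in class_sizes)
    return available < needed


# Two classes can supply at most 2 * images_per_person images per batch.
print(will_run_out_of_classes([300, 250], 45, 40))  # True: 80 < 1800
print(will_run_out_of_classes([300, 250], 2, 40))   # False: 80 >= 80
```

With only 2 classes, this means people_per_batch can effectively be no larger than 2, and each class must have at least images_per_person images if you ask for the full amount.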
I think the error occurs at class_index = class_indices[i].
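That line is consistent with the loop above: with 2 classes, class_indices has only indices 0 and 1, so once i reaches 2 the lookup is out of bounds. One possible workaround, sketched here as an assumption rather than the upstream facenet fix, is to also stop the loop when every class has been used. `ImageClass` below is a hypothetical stand-in for the dataset entries (only `image_paths` and `len()` are used), and `sample_people_guarded` is a hypothetical name:

```python
import numpy as np


class ImageClass:
    """Hypothetical stand-in for a dataset entry; only image_paths is used."""

    def __init__(self, name, image_paths):
        self.name = name
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)


def sample_people_guarded(dataset, people_per_batch, images_per_person):
    nrof_images = people_per_batch * images_per_person
    nrof_classes = len(dataset)
    class_indices = np.arange(nrof_classes)
    np.random.shuffle(class_indices)

    i = 0
    image_paths = []
    num_per_class = []
    # Stop when enough images are collected OR every class has been visited,
    # instead of indexing past the end of class_indices.
    while len(image_paths) < nrof_images and i < nrof_classes:
        class_index = class_indices[i]
        nrof_images_in_class = len(dataset[class_index])
        image_indices = np.arange(nrof_images_in_class)
        np.random.shuffle(image_indices)
        nrof_images_from_class = min(nrof_images_in_class, images_per_person,
                                     nrof_images - len(image_paths))
        idx = image_indices[0:nrof_images_from_class]
        image_paths += [dataset[class_index].image_paths[j] for j in idx]
        num_per_class.append(nrof_images_from_class)
        i += 1
    return image_paths, num_per_class


dataset = [ImageClass("a", ["a%d.jpg" % k for k in range(5)]),
           ImageClass("b", ["b%d.jpg" % k for k in range(3)])]
paths, counts = sample_people_guarded(dataset, 45, 40)
print(len(paths), sorted(counts))  # 8 [3, 5]
```

This only avoids the crash; it returns fewer than people_per_batch * images_per_person paths. If other parts of the training script assume the full batch size, the more robust route is probably to choose people_per_batch and images_per_person so that the two classes can actually supply that many images.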