train_tripletloss.py raises IndexError: index 2 is out of bounds for axis 0 with size 2

Time: 2018-08-29 06:53:58

Tags: python-3.x

I have a dataset with only 2 classes, and I am trying to train on it with train_tripletloss.py, following https://github.com/davidsandberg/facenet/wiki/Triplet-loss-training

Please help me solve this problem.

This is the part of the code where the error occurs:

def sample_people(dataset, people_per_batch, images_per_person):
    nrof_images = people_per_batch * images_per_person

    # Sample classes from the dataset
    nrof_classes = len(dataset)
    class_indices = np.arange(nrof_classes)
    np.random.shuffle(class_indices)

    i = 0
    image_paths = []
    num_per_class = []
    sampled_class_indices = []
    # Sample images from these classes until we have enough
    while len(image_paths)<nrof_images:
        class_index = class_indices[i]
        nrof_images_in_class = len(dataset[class_index])
        image_indices = np.arange(nrof_images_in_class)
        np.random.shuffle(image_indices)
        nrof_images_from_class = min(nrof_images_in_class, images_per_person, nrof_images-len(image_paths))
        idx = image_indices[0:nrof_images_from_class]
        image_paths_for_class = [dataset[class_index].image_paths[j] for j in idx]
        sampled_class_indices += [class_index]*nrof_images_from_class
        image_paths += image_paths_for_class
        num_per_class.append(nrof_images_from_class)
        i+=1

    return image_paths, num_per_class
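
The while loop above keeps visiting classes until it has collected nrof_images = people_per_batch * images_per_person paths, but it visits each class at most once and takes at most images_per_person images from it. The batch can therefore only be filled if the sum of min(class size, images_per_person) over all classes reaches nrof_images; otherwise i runs past the last class. A minimal sketch of that check, assuming the class sizes can be obtained as [len(cls) for cls in dataset] (as the snippet itself does with len(dataset[class_index])); the helper name can_fill_batch is hypothetical and not part of facenet:

```python
def can_fill_batch(class_sizes, people_per_batch, images_per_person):
    """Return True if sample_people's while loop can collect a full batch
    before running out of classes (hypothetical helper, not part of facenet)."""
    nrof_images = people_per_batch * images_per_person
    # Each class is visited once and contributes at most images_per_person paths.
    collectable = sum(min(size, images_per_person) for size in class_sizes)
    return collectable >= nrof_images


# Example: two classes with 50 images each.
print(can_fill_batch([50, 50], people_per_batch=2, images_per_person=40))  # True (80 >= 80)
print(can_fill_batch([50, 50], people_per_batch=3, images_per_person=40))  # False (80 < 120)
```

With only 2 classes, at most 2 * images_per_person paths can ever be collected, so people_per_batch has to be at most 2 (and the classes need enough images) for the loop to finish.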

I think the error occurs at class_index = class_indices[i].
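
That line is consistent with the message: class_indices is a NumPy array of size 2 (one entry per class), so once both classes have been used and the batch is still not full, i becomes 2 and class_indices[2] raises exactly "index 2 is out of bounds for axis 0 with size 2". The following is a hedged sketch of a variant of sample_people that reports this situation with a readable message instead of crashing. It assumes dataset entries support len() and expose an image_paths list, as the snippet expects; the ImageClass stand-in and the name sample_people_checked are hypothetical. It drops sampled_class_indices because the quoted function never returns it. This only makes the failure clearer; the batch still cannot be filled unless people_per_batch and images_per_person fit what the 2-class dataset can supply.

```python
import numpy as np


class ImageClass:
    """Stand-in for a dataset entry: a class name plus its image paths."""

    def __init__(self, name, image_paths):
        self.name = name
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)


def sample_people_checked(dataset, people_per_batch, images_per_person):
    """Same sampling loop as sample_people, but raises ValueError instead of
    IndexError when every class has been used and the batch is still short."""
    nrof_images = people_per_batch * images_per_person
    class_indices = np.arange(len(dataset))
    np.random.shuffle(class_indices)

    image_paths = []
    num_per_class = []
    i = 0
    while len(image_paths) < nrof_images:
        if i >= len(class_indices):
            # The original code would evaluate class_indices[i] here and crash.
            raise ValueError(
                f"Only {len(image_paths)} of {nrof_images} images could be sampled "
                f"from {len(dataset)} classes; lower people_per_batch or "
                f"images_per_person, or add more classes/images."
            )
        class_index = class_indices[i]
        image_indices = np.arange(len(dataset[class_index]))
        np.random.shuffle(image_indices)
        n = min(len(dataset[class_index]), images_per_person,
                nrof_images - len(image_paths))
        image_paths += [dataset[class_index].image_paths[j] for j in image_indices[:n]]
        num_per_class.append(n)
        i += 1
    return image_paths, num_per_class
```

For two classes of 50 images each, sample_people_checked(dataset, 2, 40) returns 80 paths split 40/40, while sample_people_checked(dataset, 45, 40) raises the ValueError above.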

0 answers: no answers yet.