我正在处理分类问题。以下代码仅在将数据划分为100个类时才有效。如果我将我的数据分成少于100个类,它会给我一个错误,即索引超出范围。我的训练标签是85000,测试标签是15000可以请一些人告诉我它是在给出这个错误以及如何修复它?
def dense_to_one_hot(labels_dense, num_classes):
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
def extract_labels(labels,num_classes, one_hot=False):
if one_hot :
return dense_to_one_hot(labels,num_classes)
return labels