代码中的修改将其分类为少于100个类

时间:2017-06-01 08:02:00

标签: python-2.7 tensorflow

我正在处理分类问题。以下代码仅在将数据划分为100个类时才有效。如果我将我的数据分成少于100个类,它会给我一个错误,即索引超出范围。我的训练标签是85000,测试标签是15000可以请一些人告诉我它是在给出这个错误以及如何修复它?

def dense_to_one_hot(labels_dense, num_classes):
    num_labels = labels_dense.shape[0]
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot

def extract_labels(labels,num_classes, one_hot=False):
    if one_hot :
        return dense_to_one_hot(labels,num_classes)
    return labels

0 个答案:

没有答案