我正在使用以下函数进行一次热表示。
def dense_to_one_hot(labels_dense, num_classes=2):
num_labels = labels_dense.shape[0]
print("num_labels")
index_offset = numpy.arange(num_labels) * num_classes
print("index_offset: ",index_offset )
labels_one_hot = numpy.zeros((num_labels, num_classes))
print(labels_one_hot.shape)
print(labels_dense.ravel())
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
print(labels_one_hot)
return labels_one_hot
我遇到了不支持的迭代器索引的错误。
train_label_1 = np.load(os.path.join("Input","train_label.npy"))
train_label = train_label_1[0:10]
print("train_label")
print(train_label)
labels_one_hot = dense_to_one_hot(train_label)
print("label_one_hot")
print(labels_one_hot)
当我打印train_label时,得到以下结果。
[0. 1. 0. 1. 0. 0. 1. 0. 0. 1.]
当我将其转换为numpy数组,然后调用dense_to_one_hot
函数时,它工作正常,但是当我仅从文件中加载火车标签并调用此函数时,它就开始给我带来错误unsupported iterator index
。可以请一个人指导我如何解决它。