一种热表示在两种情况下的行为不同

时间:2018-09-06 03:06:52

标签: python numpy

我正在使用以下函数进行一次热表示。

def dense_to_one_hot(labels_dense, num_classes=2):
    num_labels = labels_dense.shape[0]
    print("num_labels")
    index_offset = numpy.arange(num_labels) * num_classes
    print("index_offset: ",index_offset )
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    print(labels_one_hot.shape)
    print(labels_dense.ravel())
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    print(labels_one_hot)
    return labels_one_hot

我遇到了不支持的迭代器索引的错误。

train_label_1 = np.load(os.path.join("Input","train_label.npy"))
    train_label = train_label_1[0:10]
    print("train_label")
    print(train_label)
    labels_one_hot = dense_to_one_hot(train_label)
    print("label_one_hot")
    print(labels_one_hot)

当我打印train_label时,得到以下结果。

[0. 1. 0. 1. 0. 0. 1. 0. 0. 1.]

当我将其转换为numpy数组,然后调用dense_to_one_hot函数时,它工作正常,但是当我仅从文件中加载火车标签并调用此函数时,它就开始给我带来错误unsupported iterator index。可以请一个人指导我如何解决它。

0 个答案:

没有答案