tflearn标签编码与大量的类

时间:2017-04-07 13:24:26

标签: python tflearn

我正在尝试调整Convolutional Neural Net example of tflearn以使用~12000个不同的类标签和超过100万个训练示例进行分类。当对其进行单热编码时,标签的数量在内存消耗方面显然是个问题。我首先将字符串标签映射到连续整数,然后将它们作为列表传递给to_categorical()函数。以下代码导致MemoryError:

trainY = to_categorical(trainY, nb_classes=n_classes)

我是否必须像这样对标签进行编码,还是应该使用与交叉熵不同的损失函数?我可以用tflearn批量训练 - 我可以将生成器传递给DNN.fit()函数吗?

感谢您的任何建议!

1 个答案:

答案 0 :(得分:2)

在回归图层link中,您可以指定送入的标签应在运行时进行一次热编码

tflearn.layers.regression(incoming_net,
                          loss = 'categorical_crossentropy',
                          batch_size = 64,
                          to_one_hot = True,
                          n_classes = 12000)

这样就不会出现内存错误,因为标签会在训练时分批编码。