np.eye()内存错误

时间:2018-04-30 17:40:06

标签: python python-3.x numpy

我正在尝试在批量图像及其相应标签上训练机器学习模型。标签在送入网络之前必须进行单热编码。

这就是训练循环的样子:

# epochs = 500
# iterations = 50
# nclasses = 15,000
# batch_size = 64

for e in range(epochs):
    print("EPOCH %d" % e)
    for i in range(iterations):
        # batch() produces an array of labels of size (64,)
        imgs, lbls = batch()
        # Batch-wise 1-Hot Encoding
        lbls = np.eye(nclasses)[lbls] # <-- Memory Error here!
        # More processing

每当我尝试运行它时,我会在lbls = np.eye(nclasses)[lbls]中得到一个内存错误,而不是马上,但是在第60个时代的中间左右。这个数字是相当大的(15k),我玩过不同的批量大小,但所有这些都是在发生内存错误时推迟。

batch()执行以下操作:

def batch():
# Provides a random array of images and labels of size batch_size
# labels_train is an array of shape (1217958,)
    p = np.random.choice(np.array(range(train_total)), batch_size, replace=False)
    imgs = create_images(list(inputs_train[p]))
    return imgs, labels_train[p]

我真的很感激,如果有人能告诉我为什么这会发生在中间而不是马上,如果你可以建议一个解决方法!非常感谢你!

0 个答案:

没有答案