馈送(大)numpy数组时,train_test_split导致RAM崩溃

时间:2019-08-14 11:24:36

标签: numpy scikit-learn

train_test_split方法来自Scikit-learn,当向X输入形状为(5621, 224, 224, 3)y的形状为(5621, 3)的numpy数组时,RAM崩溃并终止执行。 / p>

  • X包含5621张224x224 RGB数据的图像。
  • y包含5621个3类的OneHot编码标签。

我当时正在加载一些图像作为训练数据来馈送卷积神经网络,但是当分成训练和测试数据时,它崩溃了。是否有另一种选择来加载图像以避免这种内存消耗?

复制步骤:

import numpy as np
from sklearn.model_selection import train_test_split

# Generate dummy data
X = np.random.random((5621, 224, 224, 3))
y = np.random.randint(3, size=(5621, 3))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, shuffle=True) # Breaks here

我希望将3766个火车样本和1855个测试样本作为输出,但是它将发送SIGKILL(和100%RAM使用率)并退出执行。

1 个答案:

答案 0 :(得分:0)

您确定它在拆分方法上还是在以前已经崩溃了吗?

您也可以手动拆分它:

class HRXNumberField: NSTextField {
    override func textDidChange(_ notification: Notification) {
        let invalidChars = NSCharacterSet(charactersIn: "1234567890.").inverted
        let validChars = stringValue.components(separatedBy: invalidChars)
        stringValue = validChars.joined()
        let separatedByDot = stringValue.components(separatedBy: ".")
        if separatedByDot.count > 2 {
            stringValue = separatedByDot[0] + "." + separatedByDot.dropFirst().joined()
        }
        print(self.stringValue)
    }
}

您的数据已经是随机的,因此应该不是排序问题。