check_arrays()限制scikit-learn中的数组维度?

时间:2014-05-21 17:28:17

标签: python-2.7 scikit-learn

我想使用scikit-learn X0.15中提供的sklearn.learning_curves.py。在我克隆了这个版本之后,几个函数不再起作用,因为check_arrays()将数组的维度限制为2。

>>> from sklearn import metrics 
>>> from sklearn.cross_validation import train_test_split 
>>> import numpy as np
>>> X = np.random.random((10,2,2,2))
>>> y = np.random.random((10,2,2,2))
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=3)
>>> error "Found array with dim 4d. Expected <= 2"

使用相同的X和y我得到同样的错误。

>>> mse = metrics.mean_squared_error
>>> mse(X,y)
>>> error "Found array with dim 4d. Expected <= 2"

如果我去sklearn.utils.validation.py并注释掉第272,273和274行,如下所示,一切正常。

# if array.ndim >= 3:
#     raise ValueError("Found array with dim %d. Expected <= 2" %
#                      array.ndim)

为什么数组的尺寸限制为2?

1 个答案:

答案 0 :(得分:0)

因为scikit-learn对所有要素数据使用二维约定(n_samples×n_features)。如果任何函数或方法允许更高的d数组通过,那通常只是疏忽,你不能真正依赖它。