Convert a structured array to a NumPy array for use with Scikit-Learn

Date: 2018-03-03 14:55:44

Tags: python arrays numpy

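I'm having a hard time converting a structured array, loaded from a CSV with np.genfromtxt, into an np.array so that I can fit the data to a Scikit-Learn estimator. The problem is that at some point a cast from a structured array to a regular array takes place, which raises ValueError: Can't cast from structure to non-structure. For a long time I used .view to do the conversion, but that now makes NumPy issue deprecation warnings. The code is as follows: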

import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

data = np.genfromtxt(path, dtype=float, delimiter=',', names=True)

target = "occupancy"
features = [
    "temperature", "relative_humidity", "light", "C02", "humidity"
]

# Doesn't work directly
X = data[features]
y = data[target].astype(int)

clf = GradientBoostingClassifier(random_state=42)
clf.fit(X, y)

The exception raised is: ValueError: Can't cast from structure to non-structure, except if the structure only has a single field.
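To see why the fit fails, here is a minimal sketch that uses a tiny inline CSV in place of the real Occupancy file (the three columns and rows are assumptions for illustration, with values taken from the sample data further down). Selecting several fields with data[features] still gives a 1-D structured array with named fields, not the 2-D numeric matrix Scikit-Learn expects:

```python
import io
import numpy as np

# Hypothetical stand-in for the Occupancy CSV: a header row plus three records.
csv_text = """temperature,light,occupancy
23.18,426.0,1
23.15,429.5,1
20.89,423.5,1
"""

# names=True turns the header into field names, so the result is structured.
data = np.genfromtxt(io.StringIO(csv_text), dtype=float, delimiter=",", names=True)

features = ["temperature", "light"]
X = data[features]

# Multi-field selection keeps the structured dtype: one record per row,
# addressed by field name rather than by column index.
print(X.shape)        # (3,)
print(X.dtype.names)  # ('temperature', 'light')
```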

My second attempt was to use a view, like this:

# View is raising deprecation warnings
X = data[features]
X = X.view((float, len(X.dtype.names)))
y = data[target].astype(int)

This works and does exactly what I want (I don't need a copy of the data), but it triggers a deprecation warning:

FutureWarning: Numpy has detected that you may be viewing or writing to 
an array returned by selecting multiple fields in a structured array.

This code may break in numpy 1.15 because this will return a view 
instead of a copy -- see release notes for details.

Currently we convert the structured array to a list with tolist() and then to an np.array. This works, but it seems very inefficient:

# Current method (inefficient?)
X = np.array(data[features].tolist())
y = data[target].astype(int)

There must be a better way; I'd appreciate any suggestions.

Note: the data for this example comes from the UCI ML Occupancy Repository and looks like this:

array([(nan, 23.18, 27.272 , 426.  ,  721.25, 0.00479299, 1.),
       (nan, 23.15, 27.2675, 429.5 ,  714.  , 0.00478344, 1.),
       (nan, 23.15, 27.245 , 426.  ,  713.5 , 0.00477946, 1.), ...,
       (nan, 20.89, 27.745 , 423.5 , 1521.5 , 0.00423682, 1.),
       (nan, 20.89, 28.0225, 418.75, 1632.  , 0.00427949, 1.),
       (nan, 21.  , 28.1   , 409.  , 1864.  , 0.00432073, 1.)],
      dtype=[('datetime', '<f8'), ('temperature', '<f8'), ('relative_humidity', '<f8'), 
             ('light', '<f8'), ('C02', '<f8'), ('humidity', '<f8'), ('occupancy', '<f8')])

2 Answers:

Answer 0 (score: 3)

Add .copy() to data[features]:

X = data[features].copy()
X = X.view((float, len(X.dtype.names)))

and the FutureWarning message goes away.

This is more efficient than converting to a list first.
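A hedged note for newer NumPy versions: starting with NumPy 1.16, data[features] returns a view whose dtype keeps the byte offsets of the unselected fields, so each record is still 56 bytes wide and .copy().view((float, 5)) can fail with a size error. The numpy.lib.recfunctions module provides repack_fields and structured_to_unstructured for this situation. A minimal sketch on two rows of the sample data (variable names are illustrative):

```python
import numpy as np
from numpy.lib import recfunctions as rfn

nan = np.nan
data = np.array(
    [(nan, 23.18, 27.272, 426.0, 721.25, 0.00479299, 1.0),
     (nan, 23.15, 27.2675, 429.5, 714.0, 0.00478344, 1.0)],
    dtype=[("datetime", "<f8"), ("temperature", "<f8"), ("relative_humidity", "<f8"),
           ("light", "<f8"), ("C02", "<f8"), ("humidity", "<f8"), ("occupancy", "<f8")],
)
features = ["temperature", "relative_humidity", "light", "C02", "humidity"]

# Option 1: pack the selected fields into a contiguous copy without padding,
# then reinterpret the bytes as float64 and reshape to (rows, fields).
packed = rfn.repack_fields(data[features])
X1 = packed.view(np.float64).reshape(len(packed), -1)

# Option 2: let recfunctions build the plain 2-D array directly.
X2 = rfn.structured_to_unstructured(data[features])

print(X1.shape)  # (2, 5)
```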

Answer 1 (score: 1)

If you can read the data into a plain NumPy array in the first place (by omitting the names argument), you can avoid the copy altogether:

data = np.genfromtxt(path, dtype=float, delimiter=',', skip_header=1)

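Then (fortunately) X consists of every column except the first and the last (that is, the datetime and occupancy columns are left out), so X and y can be expressed as slices: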

X = data[:, 1:-1]
y = data[:, -1].astype(int)

Then we can pass these arrays straight to scikit-learn:

clf = GradientBoostingClassifier(random_state=42)
clf.fit(X, y)
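A quick sketch of why the slicing approach avoids a copy, using a three-row inline CSV built from the sample values (the inline text is an assumption standing in for path). Basic slicing returns a view into data, which np.shares_memory confirms, while astype(int) always produces a new array:

```python
import io
import numpy as np

# Hypothetical three-row CSV with the same column layout as the Occupancy data.
csv_text = """datetime,temperature,relative_humidity,light,C02,humidity,occupancy
nan,23.18,27.272,426.0,721.25,0.00479299,1
nan,23.15,27.2675,429.5,714.0,0.00478344,1
nan,20.89,27.745,423.5,1521.5,0.00423682,1
"""

# skip_header=1 drops the names row, so the result is a plain 2-D float array.
data = np.genfromtxt(io.StringIO(csv_text), dtype=float, delimiter=",", skip_header=1)

X = data[:, 1:-1]            # every column except the first and the last
y = data[:, -1].astype(int)  # astype returns a new array

print(X.shape)                    # (3, 5)
print(np.shares_memory(X, data))  # True: basic slicing gives a view
print(np.shares_memory(y, data))  # False
```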

And, if we want, we can afterwards view the plain NumPy array as a structured array:

features = ["temperature", "relative_humidity", "light", "C02", "humidity"]
X = X.ravel().view([(field, X.dtype.type) for field in features])
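A small check of this round trip, using a synthetic 2 x 7 float array in place of the CSV (the arange values are placeholders). Because data[:, 1:-1] is not C-contiguous, X.ravel() has to make a copy here, so the structured result does not share memory with data:

```python
import numpy as np

# Placeholder for data loaded with skip_header=1: 2 rows x 7 float64 columns.
data = np.arange(14, dtype=float).reshape(2, 7)
X = data[:, 1:-1]

features = ["temperature", "relative_humidity", "light", "C02", "humidity"]

# X is a column slice, so it is not C-contiguous; ravel() returns a
# contiguous copy, which view() then reinterprets as 40-byte records.
flat = X.ravel()
X_struct = flat.view([(field, X.dtype.type) for field in features])

print(X.flags["C_CONTIGUOUS"])       # False
print(np.shares_memory(flat, data))  # False: ravel had to copy
print(X_struct.shape)                # (2,)
```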

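Unfortunately, this workaround relies on X being expressible as a slice: if, for example, occupancy appeared between other feature columns, we could not avoid a copy. It also means you have to define X as X = data[:, 1:-1] instead of the more readable X = data[features].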
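To illustrate that limitation, suppose (hypothetically) that the target sat in column 4, between the feature columns. The features can then only be gathered with integer-array indexing, which always returns a copy (the arange data is a placeholder):

```python
import numpy as np

# Placeholder 2 x 7 array standing in for the loaded CSV.
data = np.arange(14, dtype=float).reshape(2, 7)

# Hypothetical layout: features in columns 1, 2, 3 and 5, target in column 4.
feature_cols = [1, 2, 3, 5]
X = data[:, feature_cols]   # integer-array (fancy) indexing copies
y = data[:, 4].astype(int)

print(np.shares_memory(X, data))  # False
```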

import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

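# skip_header=1 skips the row of column names, giving a plain 2D float array
data = np.genfromtxt(path, dtype=float, delimiter=',', skip_header=1)

# Features: every column except datetime (first) and occupancy (last)
X = data[:, 1:-1]
y = data[:, -1].astype(int)

clf = GradientBoostingClassifier(random_state=42)
clf.fit(X, y)

# Optional: reinterpret the features as a structured array.
# X is a non-contiguous slice, so ravel() copies before view() is applied.
features = ["temperature", "relative_humidity", "light", "C02", "humidity"]
X = X.ravel().view([(field, X.dtype.type) for field in features])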

If you must start from a structured array, hpaulj's answer shows how to view/reshape/slice the structured array to get a plain array without copying:

import numpy as np
nan = np.nan
data = np.array([(nan, 23.18, 27.272 , 426.  ,  721.25, 0.00479299, 1.),
       (nan, 23.15, 27.2675, 429.5 ,  714.  , 0.00478344, 1.),
       (nan, 23.15, 27.245 , 426.  ,  713.5 , 0.00477946, 1.), 
       (nan, 20.89, 27.745 , 423.5 , 1521.5 , 0.00423682, 1.),
       (nan, 20.89, 28.0225, 418.75, 1632.  , 0.00427949, 1.),
       (nan, 21.  , 28.1   , 409.  , 1864.  , 0.00432073, 1.)],
      dtype=[('datetime', '<f8'), ('temperature', '<f8'), ('relative_humidity', '<f8'), 
             ('light', '<f8'), ('C02', '<f8'), ('humidity', '<f8'), ('occupancy', '<f8')])

target = 'occupancy'
nrows = len(data)
X = data.view('<f8').reshape(nrows, -1)[:, 1:-1]
y = data[target].astype(int)

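This relies on every field being 8 bytes long (all are '<f8'), which makes it easy to reinterpret the structured array as a plain array of dtype <f8. The reshape turns it into a 2D array with the same number of rows, and the slice drops the datetime and occupancy columns/fields.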
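The "without copying" claim can be checked with np.shares_memory. A self-contained sketch using two rows of the sample data:

```python
import numpy as np

nan = np.nan
dtype = [("datetime", "<f8"), ("temperature", "<f8"), ("relative_humidity", "<f8"),
         ("light", "<f8"), ("C02", "<f8"), ("humidity", "<f8"), ("occupancy", "<f8")]
data = np.array([(nan, 23.18, 27.272, 426.0, 721.25, 0.00479299, 1.0),
                 (nan, 20.89, 27.745, 423.5, 1521.5, 0.00423682, 1.0)], dtype=dtype)

# All seven fields are '<f8' with no padding, so each record is 7 * 8 = 56 bytes
# and the buffer can be reinterpreted as consecutive float64 values.
assert data.dtype.itemsize == 7 * 8

flat = data.view("<f8")               # shape (14,): one value per field
table = flat.reshape(len(data), -1)   # shape (2, 7): one row per record
X = table[:, 1:-1]                    # drop datetime and occupancy

print(X.shape)                    # (2, 5)
print(np.shares_memory(X, data))  # True: no copy was made
```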