Problem using genfromtxt to feed the scikit-learn fit function

Time: 2014-01-08 19:15:23

Tags: csv numpy scikit-learn

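I am trying to read a csv file with genfromtxt and then use RandomForestClassifier. I end up calling genfromtxt twice: once to read the feature names and once to get the data in the right format. The code for this attempt is as follows:

import csv
import numpy as np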

data = np.genfromtxt('plants.csv',dtype=float, delimiter=',', names=True)
feature_names = np.array(data.dtype.names)
feature_names = feature_names[[ 0,1,2,3,4]] 

data = np.genfromtxt('plants.csv',dtype=float, delimiter=',', skip_header=1)
plants_X = data[:, [0,1,2,3,4]] 
plants_y = np.ravel(data[:, [5]])  # Flatten to the 1-D array that scikit-learn's fit expects as its second argument

from sklearn.ensemble import RandomForestClassifier 
clf = RandomForestClassifier( n_estimators = 10, random_state = 33)
clf = clf.fit(plants_X, plants_y)

print(feature_names, '\n', clf.feature_importances_)

When I use genfromtxt with the names=True option, the data is read in a format I did not expect:

([(31.194181, 0.0, 0.0, 0.0, 1.0, 1.0),
  (12.0, 0.0, 0.0, 1.0, 0.0, 1.0), (18.0, 1.0, 0.0, 1.0, 0.0, 0.0),
  (31.194181, 0.0, 0.0, 0.0, 1.0, 0.0)],
  ...
 dtype=[('A', '

I would like to get the feature names from the file without reading it twice!

Thanks for your help!

P.S. Hats off to "Cyborg" for getting me this far!

1 Answer:

Answer 0: (score: 2)

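I suggest using pandas. You can use pandas.read_csv to get a pandas DataFrame with column names. You will then need to convert the data to a numpy array in order to pass it to scikit-learn.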
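A minimal sketch of this suggestion, under some assumptions: the answer does not say which conversion it had in mind, so `DataFrame.to_numpy()` is used here as one option. The column names `A` to `F` are hypothetical (the question only shows the first name, `A`), and an in-memory string stands in for `plants.csv`, built from the four rows printed in the question.

```python
import io

import numpy as np
import pandas as pd

# Stand-in for plants.csv: the column names A..F are hypothetical,
# and the rows are the four records shown in the question.
csv_text = """A,B,C,D,E,F
31.194181,0.0,0.0,0.0,1.0,1.0
12.0,0.0,0.0,1.0,0.0,1.0
18.0,1.0,0.0,1.0,0.0,0.0
31.194181,0.0,0.0,0.0,1.0,0.0
"""

# Read the file once; the header row becomes the column index.
df = pd.read_csv(io.StringIO(csv_text))

# Assumed layout, as in the question: five feature columns, then the target.
feature_names = np.array(df.columns[:5])
plants_X = df.iloc[:, :5].to_numpy()  # 2-D float array, shape (n_samples, 5)
plants_y = df.iloc[:, 5].to_numpy()   # 1-D array, shape (n_samples,)

print(feature_names)
print(plants_X.shape, plants_y.shape)
```

With the real file, `pd.read_csv('plants.csv')` would replace the `io.StringIO` call, and `plants_X` and `plants_y` could then be passed to `clf.fit(plants_X, plants_y)` as in the question. On older pandas versions, the `.values` attribute gives the same arrays.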