将预测结果保存为CSV

时间:2016-01-18 21:50:07

标签: python numpy pandas scikit-learn

我将sklearn回归模型的结果存储到varibla预测中。

prediction = regressor.predict(data[['X']])
print(prediction)

预测输出的值如下所示

[ 266.77832991  201.06347505  446.00066136  499.76736079  295.15519906
  214.50514991  422.1043505   531.13126879  287.68760191  201.06347505
  402.68859792  478.85808879  286.19408248  192.10235848]

然后我尝试使用to_csv函数将结果保存到本地CSV文件:

prediction.to_csv('C:/localpath/test.csv')

但我得到的错误是:

AttributeError: 'numpy.ndarray' object has no attribute 'to_csv'

我正在使用Pandas / Numpy / SKlearn。关于基本修复的任何想法?

3 个答案:

答案 0 :(得分:19)

你可以使用熊猫。 如上所述,numpy数组没有to_csv函数。

import numpy as np
import pandas as pd
prediction = pd.DataFrame(predictions, columns=['predictions']).to_csv('prediction.csv')

如果您希望您的值在行或列中,请添加“.T”。

答案 1 :(得分:10)

您可以使用numpy.savetxt功能:

numpy.savetxt('C:/localpath/test.csv',prediction, ,delimiter=',')

要加载CSV文件,您可以使用numpy.genfromtxt功能:

numpy.genfromtxt('C:/localpath/test.csv', delimiter=',')

答案 2 :(得分:1)

这是一个非常详细的解决方案案例,但是您甚至可以在生产中使用它。

首先保存模型

joblib.dump(regressor, "regressor.sav")

按顺序保存列

pd.DataFrame(X_train.columns).to_csv("feature_list.csv", index = None)

保存火车组的数据类型

pd.DataFrame(X_train.dtypes).reset_index().to_csv("data_types.csv", index = None)

再次使用:

feature_list = pd.read_csv("feature_list.csv")
feature_list = pd.Index(list(feature_list["0"]))

add_cols = list(feature_list.difference(X_test.columns))

drop_cols = list(X_test.columns.difference(feature_list))

for col in add_cols:
    X_test[col] = np.nan

for col in drop_cols:
    X_test = X_test.drop(col, axis = 1)

# reorder columns
X_test = X_test[feature_list]

types = pd.read_csv("data_types.csv")
for i in range(len(types)):
    X_test[types.iloc[i,0]] = X_test[types.iloc[i,0]].astype(types.iloc[i,1])

做出预测

regressor = joblib.load("regressor.sav")
predictions = regressor.predict(X_test)

保存预测结果

res = pd.DataFrame(predictions)
res.index = X_test.index # its important for comparison
res.columns = ["prediction"]
res.to_csv("prediction_results.csv")

享受端到端的模型/预测保护程序代码!