从rpy2 Random Forest对象获取字段值

时间:2015-01-06 15:06:27

标签: python r random-forest rpy2

我正在尝试使用Python运行R Random Forest实现。我正在使用rpy2模块轻松完成这项工作。以下是随机生成数据的简单示例:

import numpy as np
from rpy2.robjects.numpy2ri import numpy2ri
from rpy2.robjects.packages import importr
from rpy2 import robjects as ro 

#create data
X np.random.rand(30,100)
#create y-values
y = np.random.randint(2, size=30)
X = numpy2ri(X)
y = ro.FactorVector(numpy2ri(y))
#build RF
model = rf.randomForest(X, y)

现在,如何从python中访问模型的所有字段?如何获得错误率或变量重要性?简而言之:

model$importance[,"MeanDecreaseGini"]

如何使用rpy2完成此操作?如何访问模型对象的所有字段?

2 个答案:

答案 0 :(得分:2)

您可以使用.rx访问字段:

>>> model.rx('importance')[0]
  <Matrix - Python:0x1126137e8 / R:0x10a292290>
[0.259480, 0.076463, 0.240162, ..., 0.049585, 0.249498, 0.043696]

答案 1 :(得分:2)

使用pandas,您可以指定列名,然后使用rpy2 / R interface将数据帧转换为保留字段名称的R对象。

import pandas as pd
import rpy2.robjects as robjects
import pandas.rpy.common as com
import numpy as np
r = robjects.r

r.library("randomForest")

# generate a pandas dataframe with random numbers
df = pd.DataFrame(data=np.random.rand(100, 30), columns=["a{}".format(i) for i in range(30)])
df["b"] = np.random.randint(2, size=100)

# create r objects
X = com.convert_to_r_dataframe(df.drop("b", axis=1))
Y = robjects.FactorVector(df.b)

# build rf model
rf = r.randomForest(X, Y)

# print Mean Decrease Gini and Field names
print rf.rx("importance")
print r.dimnames(rf[8])

返回

randomForest 4.6-7
Type rfNews() to see new features/changes/bug fixes.
$importance
    MeanDecreaseGini
a0          3.264841
a1          1.889741
a2          1.836287
a3          1.397774
a4          2.004300
a5          1.973436
a6          1.282584
a7          1.834799
a8          1.891645
a9          1.607779
a10         1.926996
a11         1.431277
a12         1.605571
a13         2.372562
a14         1.342930
a15         1.596201
a16         1.402425
a17         1.161261
a18         1.423914
a19         1.532494
a20         1.182701
a21         1.328816
a22         1.654255
a23         1.437174
a24         1.312123
a25         1.698160
a26         1.545838
a27         2.169778
a28         1.314767
a29         1.048250

...和你在R

中看到的字段名称