I am saving a pyspark.ml.regression.RandomForestRegressionModel to HDFS:
from pyspark.ml.regression import RandomForestRegressor, RandomForestRegressionModel
regresor = RandomForestRegressor(
    maxDepth=16,
    numTrees=100,
    seed=SEED,
    impurity="variance"
)
model = regresor.fit(trainingData)
model.save("random_forest")
Listing the contents shows the following output:
[username@node ~]$ hdfs dfs -du -h /user/username/random_forest
70.7 M 212.2 M /user/username/random_forest/data
509 1.5 K /user/username/random_forest/metadata
19.6 K 58.7 K /user/username/random_forest/treesMetadata
Each directory contains a large number of Parquet files. I would like to know what is stored in each of them. Is a full copy of the training data stored in /user/username/random_forest/data?
I am also wondering how the total size of these directories relates to the size of the training data and the complexity of the model.
Answer 0 (score: 1)
The model does not contain a copy of the training data, but its size does depend on the model parameters (for example, the number of trees). Let's look at a small example with only two trees:
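As a rough sketch of why the parameters matter for size: a binary tree of depth maxDepth has at most 2^(maxDepth+1) - 1 nodes, and the data directory stores roughly one row per node per tree, so numTrees × (2^(maxDepth+1) - 1) is a worst-case bound on the stored node count. The numbers below use maxDepth=16 and numTrees=100 from the question; real trees are usually much smaller because branches stop splitting early, so this is only an upper bound, not the actual on-disk size (which also depends on Parquet encoding and compression):

```python
# Worst-case node count for a forest of full binary trees.
max_depth = 16   # maxDepth from the question
num_trees = 100  # numTrees from the question

nodes_per_tree = 2 ** (max_depth + 1) - 1   # full binary tree of depth 16
total_nodes = num_trees * nodes_per_tree

print(nodes_per_tree)  # 131071
print(total_nodes)     # 13107100
```

This is why the data directory grows with numTrees and (potentially exponentially) with maxDepth, independently of how many training rows were used.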
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator
# Load and parse the data file, converting it to a DataFrame.
data = spark.read.format("libsvm").load("/tmp/sample_libsvm_data.txt")
# Automatically identify categorical features, and index them.
# Set maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer = VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data)
data = featureIndexer.transform(data)
# Train a RandomForest model.
rf = RandomForestRegressor(featuresCol="indexedFeatures", numTrees=2)
# Train model.
model = rf.fit(data)
model.write().overwrite().save('/tmp/randForest')
This generates the following directories:

/tmp/randForest/data
/tmp/randForest/treesMetadata
/tmp/randForest/metadata

The output of model.debugString is (I added the indentation):
RandomForestRegressionModel (uid=RandomForestRegressor_7fa55d0ee30d) with 2 trees
Tree 0 (weight 1.0):
  If (feature 462 <= 62.5)
    Predict: 0.0
  Else (feature 462 > 62.5)
    Predict: 1.0
Tree 1 (weight 1.0):
  If (feature 405 <= 21.0)
    If (feature 623 <= 253.5)
      Predict: 0.0
    Else (feature 623 > 253.5)
      Predict: 1.0
  Else (feature 405 > 21.0)
    If (feature 425 <= 19.0)
      Predict: 1.0
    Else (feature 425 > 19.0)
      Predict: 0.0
Let's inspect the files:
spark.read.parquet("/tmp/randForest/data/part-00000-c08a27a9-a7d3-47d5-a50d-55cd42ee12f9-c000.snappy.parquet").show(truncate=False)
#Output
+------+------------------------------------------------------------------------------------------------------------+
|treeID| nodeData |
+------+------------------------------------------------------------------------------------------------------------+
|0 |[0, 0.5483870967741935, 0.2476586888657648, [93.0, 51.0, 51.0], 0.2476586888657648, 1, 2, [462, [62.5], -1]]|
|0 |[1, 0.0, 0.0, [42.0, 0.0, 0.0], -1.0, -1, -1, [-1, [], -1]] |
|0 |[2, 1.0, 0.0, [51.0, 51.0, 51.0], -1.0, -1, -1, [-1, [], -1]] |
+------+------------------------------------------------------------------------------------------------------------+
The file in data contains the parameters (nodes, split conditions, ...) needed to reconstruct each DecisionTreeRegressionModel.
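As a sketch of how to read a nodeData row: the field order below follows Spark ML's internal NodeData representation, which is an implementation detail (not a public API) and may differ between Spark versions. The values are copied from the root node of Tree 0 in the output above:

```python
# Hypothetical annotation of the first nodeData row shown above.
# Assumed field order (internal to Spark ML): id, prediction, impurity,
# impurityStats, gain, leftChild, rightChild,
# split(featureIndex, leftCategoriesOrThreshold, numCategories).
root = {
    "id": 0,
    "prediction": 0.5483870967741935,
    "impurity": 0.2476586888657648,
    "impurityStats": [93.0, 51.0, 51.0],
    "gain": 0.2476586888657648,
    "leftChild": 1,   # row with id 1
    "rightChild": 2,  # row with id 2
    "split": {
        "featureIndex": 462,
        "leftCategoriesOrThreshold": [62.5],
        "numCategories": -1,  # -1 => continuous feature
    },
}

# Leaf rows (ids 1 and 2 above) use -1 for both children and an empty split,
# which matches "If (feature 462 <= 62.5) Predict: 0.0 / Else Predict: 1.0"
# in model.debugString.
print(root["split"]["featureIndex"])  # 462
```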
spark.read.parquet("/tmp/randForest/treesMetadata/part-00000-414aff59-3fac-44d9-8dee-4a334ce41bce-c000.snappy.parquet").show(truncate=False)
#Output
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+
|treeID|metadata |weights|
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+
|0 |{"class":"org.apache.spark.ml.regression.DecisionTreeRegressionModel","timestamp":1564155036724,"sparkVersion":"2.4.0","uid":"dtr_d963ec77a5ad","paramMap":{},"defaultParamMap":{"checkpointInterval":10,"minInfoGain":0.0,"seed":1366634793,"maxMemoryInMB":256,"maxDepth":5,"impurity":"variance","cacheNodeIds":false,"labelCol":"label","predictionCol":"prediction","maxBins":32,"featuresCol":"features","minInstancesPerNode":1}}|1.0 |
+------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------+
The file in treesMetadata contains, for each individual DecisionTreeRegressionModel, the parameters you set explicitly as well as those filled in by default when you called the fit() method of the RandomForestRegressor object.
sc.textFile("/tmp/randForest/metadata/part-00000").collect()
#Output:
['{"class":"org.apache.spark.ml.regression.RandomForestRegressionModel","timestamp":1564155036257,"sparkVersion":"2.4.0","uid":"RandomForestRegressor_7fa55d0ee30d","paramMap":{"numTrees":2,"featuresCol":"indexedFeatures"},"defaultParamMap":{"maxBins":32,"impurity":"variance","checkpointInterval":10,"numTrees":20,"labelCol":"label","maxDepth":5,"predictionCol":"prediction","featuresCol":"features","minInstancesPerNode":1,"featureSubsetStrategy":"auto","subsamplingRate":1.0,"seed":2502083311556356884,"minInfoGain":0.0,"cacheNodeIds":false,"maxMemoryInMB":256},"numFeatures":692,"numTrees":2}']
The file in metadata contains the parameters of the RandomForestRegressionModel itself, again split into those you set explicitly and those filled in by default when you called the fit() method of the RandomForestRegressor object.
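Since the metadata file is plain JSON, it can be inspected without Spark. A minimal sketch using a trimmed-down copy of the JSON shown above (the full file also carries timestamp, sparkVersion, uid, and the complete defaultParamMap); paramMap holds the explicitly set parameters, defaultParamMap the defaults:

```python
import json

# Trimmed-down copy of the metadata JSON shown above (illustrative only).
raw = '''{"class":"org.apache.spark.ml.regression.RandomForestRegressionModel",
"paramMap":{"numTrees":2,"featuresCol":"indexedFeatures"},
"defaultParamMap":{"maxDepth":5,"numTrees":20,"impurity":"variance"},
"numFeatures":692,"numTrees":2}'''

meta = json.loads(raw)

# Explicitly set parameters override the defaults.
params = {**meta["defaultParamMap"], **meta["paramMap"]}
print(params["numTrees"])  # 2 (explicit, overriding the default of 20)
print(params["maxDepth"])  # 5 (default)
```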