我正在使用Apache Spark的ML库构建机器学习模型,我们假设使用RandomForestClassifier。
我按如下所示将数据集划分为训练和测试
rf = RandomForestClassifier(numTrees=10,featuresCol = "features",
labelCol = "label")
model= rf.fit(tr)
prediction = model.transform(test)
eval = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction")
eval.evaluate(prediction)
应用模型
public class FarmActivity extends AppCompatActivity {
ListView listView;
FirebaseDatabase database;
DatabaseReference ref;
ArrayList<String> list;
ArrayAdapter<String> adapter;
FieldInformation fieldInformation;
@Override
protected void onCreate(@Nullable Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_farm);
listView = (ListView )findViewById(R.id.listView);
database = FirebaseDatabase.getInstance();
ref = database.getReference("Fields");
list = new ArrayList<>();
adapter = new ArrayAdapter<String>(this, R.layout.user_field_info, R.id.userInfo, list);
fieldInformation = new FieldInformation();
ref.addValueEventListener(new ValueEventListener() {
@Override
public void onDataChange(DataSnapshot dataSnapshot) {
for(DataSnapshot ds:dataSnapshot.getChildren()){
fieldInformation = ds.getValue(FieldInformation.class);
list.add(fieldInformation.getFieldName().toString() + " " + "\n"+ fieldInformation.getCrop().toString() + " ");
}
listView.setAdapter(adapter);
}
@Override
public void onCancelled(DatabaseError databaseError) {
}
});
我的印象是,这给了我AUC的准确性。如何获得此模型的精度,召回率,F1和精度?
我的课程变量是二进制(0或1)。
答案 0 :(得分:0)
AUC是ROC曲线下的面积。与准确性无关,但根据我的观点,它更有用。更好地概述了模型的功能。 您需要的所有指标都在这里: https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html#binary-classification 请注意,所有指标都是针对一个标签计算的(取决于您的真实肯定数是0还是1)。如果您的班级不平衡,并且您计算了主要班级的指标(假设为1分),那么您的结果可能会产生误导。因此,使用对于模型正确分类更为重要的标签。 请在使用指标之前先仔细阅读文档,以全面了解它们的含义。 干杯。
答案 1 :(得分:0)
您可以使用 MulticlassMetrics 来获得准确率和召回率。
predictionAndLabels = prediction.select("prediction","label").rdd
# Instantiate metrics objects
multi_metrics = MulticlassMetrics(predictionAndLabels)
precision_score = multi_metrics.weightedPrecision
recall_score = multi_metrics.weightedRecall
或者,您可以获取混淆矩阵并自行计算。
confusion_matrix = multi_metrics.confusionMatrix().toArray()