Disclaimer: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. If you use or share it, you must do so under the same license and attribute it to the original authors (not the translator). Original question: http://stackoverflow.com/questions/41714698/

How to get accuracy precision, recall and ROC from cross validation in Spark ml lib?

Tags: scala, apache-spark, machine-learning, precision-recall

Asked by user3309479

I am using Spark 2.0.2. I am also using the "ml" library for Machine Learning with Datasets. What I want to do is run algorithms with cross validation and extract the mentioned metrics (accuracy, precision, recall, ROC, confusion matrix). My data labels are binary.

By using the MulticlassClassificationEvaluator I can only get the accuracy of the algorithm by accessing "avgMetrics". Also, by using the BinaryClassificationEvaluator I can get the area under ROC. But I cannot use them both. So, is there a way that I can extract all of the wanted metrics?

Answered by ShuoshuoFan

Try using MLlib to evaluate your result.

I transformed the Dataset to an RDD, then used MulticlassMetrics from MLlib.

You can see a demo here: Spark DecisionTreeExample.scala

private[ml] def evaluateClassificationModel(
      model: Transformer,
      data: DataFrame,
      labelColName: String): Unit = {
    val fullPredictions = model.transform(data).cache()
    val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0))
    val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0))
    // Print number of classes for reference.
    val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match {
      case Some(n) => n
      case None => throw new RuntimeException(
        "Unknown failure when indexing labels for classification.")
    }
    val accuracy = new MulticlassMetrics(predictions.zip(labels)).accuracy
    println(s"  Accuracy ($numClasses classes): $accuracy")
  }
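The snippet above only prints accuracy, but the same zipped (prediction, label) RDD also yields precision, recall, and the confusion matrix through MulticlassMetrics. The underlying arithmetic is simple enough to sketch in plain Scala without a Spark session; the sample pairs below are made up purely for illustration:

```scala
object BinaryMetricsSketch {
  // Each pair is (prediction, label), both encoded as 0.0 or 1.0,
  // just like the zipped RDD passed to MulticlassMetrics above.
  // Returns the confusion-matrix cells (tp, tn, fp, fn).
  def confusion(pairs: Seq[(Double, Double)]): (Int, Int, Int, Int) = {
    val tp = pairs.count { case (p, l) => p == 1.0 && l == 1.0 }
    val tn = pairs.count { case (p, l) => p == 0.0 && l == 0.0 }
    val fp = pairs.count { case (p, l) => p == 1.0 && l == 0.0 }
    val fn = pairs.count { case (p, l) => p == 0.0 && l == 1.0 }
    (tp, tn, fp, fn)
  }

  def main(args: Array[String]): Unit = {
    val pairs = Seq((1.0, 1.0), (0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0))
    val (tp, tn, fp, fn) = confusion(pairs)
    val accuracy  = (tp + tn).toDouble / pairs.size // 3/5 = 0.6
    val precision = tp.toDouble / (tp + fp)         // 2/3
    val recall    = tp.toDouble / (tp + fn)         // 2/3
    println(f"accuracy=$accuracy%.3f precision=$precision%.3f recall=$recall%.3f")
  }
}
```

On the actual RDD, MulticlassMetrics exposes the same quantities via `accuracy`, `precision(label)`, `recall(label)`, and `confusionMatrix`.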

Answered by Darshan

You can follow the official Evaluation Metrics guide provided by Apache Spark. The documentation covers all of the evaluation metrics, including:

  • Precision (Positive Predictive Value)
  • Recall (True Positive Rate)
  • F-measure
  • Receiver Operating Characteristic (ROC)
  • Area Under ROC Curve
  • Area Under Precision-Recall Curve

Here is the link : https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html
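The area-under-ROC metric from that guide can also be reproduced by hand, which helps when checking what BinaryClassificationMetrics reports. Below is a minimal plain-Scala sketch using the pairwise (Mann-Whitney) definition of AUC; the scores and labels are invented for illustration:

```scala
object AucSketch {
  // Pairwise formulation of area under the ROC curve: the probability
  // that a randomly chosen positive outscores a randomly chosen negative,
  // counting ties as half a win. Pairs are (rawScore, label in {0.0, 1.0}).
  def areaUnderROC(pairs: Seq[(Double, Double)]): Double = {
    val pos = pairs.collect { case (s, l) if l == 1.0 => s }
    val neg = pairs.collect { case (s, l) if l == 0.0 => s }
    val wins = for (p <- pos; n <- neg)
      yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    wins.sum / (pos.size.toDouble * neg.size)
  }

  def main(args: Array[String]): Unit = {
    // One positive (score 0.4) is outranked by a negative (score 0.5),
    // so 3 of the 4 positive/negative pairs are ordered correctly: AUC = 0.75.
    val scored = Seq((0.9, 1.0), (0.4, 1.0), (0.5, 0.0), (0.1, 0.0))
    println(areaUnderROC(scored))
  }
}
```

In Spark itself, the equivalent is constructing BinaryClassificationMetrics from an RDD of (score, label) pairs and reading its `areaUnderROC` (the O(n²) pairwise loop here is only for clarity, not for large data).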
