pyspark中的记录丢失功能

时间:2018-02-28 01:01:10

标签: machine-learning pyspark cross-validation

pyspark中是否有内置的日志丢失功能?

我有一个包含列的pyspark数据框: 概率,rawPrediction,标签

我希望使用平均对数损失来评估这些预测。

2 个答案:

答案 0 :(得分:1)

AFAICT没有直接存在这样的功能。但是给定一个PySpark数据帧df,其中的列名称为问题中的列,可以显式计算平均日志丢失:

import pyspark.sql.functions as f

df = (
    df.withColumn(
        'logloss'
        , -f.col('label')*f.log(f.col('probability')) - (1.-f.col('label'))*f.log(1.-f.col('probability'))
    )
)

logloss = df.agg(f.mean('logloss').alias('ll')).collect()[0]['ll']

我假设label是数字(即0或1),而probability代表模型的预测。 (不确定rawPrediction可能意味着什么。)

答案 1 :(得分:0)

我假设您正在使用Spark ML,并且您的数据帧是拟合估算器的输出。在这种情况下,您可以使用以下函数来计算日志损失。

from pyspark.sql.types import FloatType
from pyspark.sql import functions as F

def log_loss(df):

  # extract predicted probability
  get_first_element = F.udf(lambda v:float(v[1]), FloatType())
  df = df.withColumn("preds", get_first_element(F.col("probability")))

  # calculate negative log-likelihood
  y1, y0 = F.col("label"), 1 - F.col("label")
  p1, p0 = F.col("preds"), 1 - F.col("preds")
  nll = -(y1*F.log(p1) + y0*F.log(p0))

  # aggregate
  return df.agg(F.mean(nll)).collect()[0][0]