Random Cut Forest model evaluation strategy (confusion matrix, accuracy and precision_recall_fscore)

Time: 2019-03-28 15:19:52

Tags: machine-learning random-forest amazon-sagemaker

I am using the AWS SageMaker Random Cut Forest (RCF) algorithm to detect anomalies.

import boto3
import sagemaker

containers = {
    'us-west-2': '174872318107.dkr.ecr.us-west-2.amazonaws.com/randomcutforest:latest',
    'us-east-1': '382416733822.dkr.ecr.us-east-1.amazonaws.com/randomcutforest:latest',
    'us-east-2': '404615174143.dkr.ecr.us-east-2.amazonaws.com/randomcutforest:latest',
    'eu-west-1': '438346466558.dkr.ecr.eu-west-1.amazonaws.com/randomcutforest:latest',
    'ap-southeast-1': '475088953585.dkr.ecr.ap-southeast-1.amazonaws.com/randomcutforest:latest'
    }
region_name = boto3.Session().region_name
container = containers[region_name]

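session = sagemaker.Session()

# Placeholders: bucket, prefix and s3_train_data are not defined in the original
# snippet. These values are hypothetical; point them at your own data.
bucket = session.default_bucket()
prefix = 'rcf-anomaly-detection'
s3_train_data = 's3://{}/{}/train'.format(bucket, prefix)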

rcf = sagemaker.estimator.Estimator(
    container,
    sagemaker.get_execution_role(),
    output_path='s3://{}/{}/output'.format(bucket, prefix),
    train_instance_count=1,
    train_instance_type='ml.c5.xlarge',
    sagemaker_session=session)

rcf.set_hyperparameters(
    num_samples_per_tree=200,
    num_trees=250,
    feature_dim=1,
    eval_metrics=["accuracy", "precision_recall_fscore"])

s3_train_input = sagemaker.session.s3_input(
    s3_train_data,
    distribution='ShardedByS3Key',
    content_type='application/x-recordio-protobuf')

rcf.fit({'train': s3_train_input})

(Adapted from: https://aws.amazon.com/blogs/machine-learning/use-the-built-in-amazon-sagemaker-random-cut-forest-algorithm-for-anomaly-detection/)

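I used the code above to train the model, but I could not find a way to evaluate it. How can I get the accuracy and F-score once the model is deployed?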
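If you already have a deployed endpoint and a labeled hold-out set, one option is to compute the metrics yourself on the client side. The sketch below is plain Python, not a SageMaker API. It assumes you have already collected one anomaly score per record from the endpoint and that a record counts as anomalous when its score exceeds a cutoff; the default cutoff of mean + 3 standard deviations is a hypothetical choice, and the labels and scores in the example are made up.

```python
from statistics import mean, pstdev


def confusion_matrix(labels, predictions):
    """Return (tp, fp, fn, tn), using 1 = anomaly and 0 = normal."""
    tp = fp = fn = tn = 0
    for y, p in zip(labels, predictions):
        if y == 1 and p == 1:
            tp += 1
        elif y == 0 and p == 1:
            fp += 1
        elif y == 1 and p == 0:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


def evaluate(labels, scores, cutoff=None):
    """Threshold anomaly scores, then compute accuracy, precision, recall and F1."""
    if cutoff is None:
        # Hypothetical default: flag scores more than 3 standard deviations above the mean.
        cutoff = mean(scores) + 3 * pstdev(scores)
    predictions = [1 if s > cutoff else 0 for s in scores]
    tp, fp, fn, tn = confusion_matrix(labels, predictions)
    accuracy = (tp + tn) / len(labels)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "confusion": (tp, fp, fn, tn),
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


labels = [0, 0, 1, 0, 1, 0]
scores = [0.8, 1.5, 3.5, 1.0, 1.05, 0.7]  # made-up scores for illustration
result = evaluate(labels, scores, cutoff=1.1)
print(result["confusion"])  # (1, 1, 1, 3)
```

With a cutoff of 1.1 the example yields one true positive, one false positive, one false negative and three true negatives, so accuracy is 4/6 and precision, recall and F1 are all 0.5. The choice of cutoff drives all four metrics, so it is worth reporting alongside them.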

1 Answer:

Answer 0 (score: 0)

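To get evaluation metrics, you need to provide an additional channel named "test" during training. The test channel must contain labeled data. This is explained in the official documentation at https://docs.aws.amazon.com/sagemaker/latest/dg/randomcutforest.html: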

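Amazon SageMaker Random Cut Forest supports the train and test data channels. The optional test channel is used to compute accuracy, precision, recall, and F1-score metrics on labeled data. Train and test data content types can be either application/x-recordio-protobuf or text/csv formats. For the test data, when using text/csv format, the content must be specified as text/csv;label_size=1, where the first column of each row represents the anomaly label: "1" for an anomalous data point and "0" for a normal data point. You can use either File mode or Pipe mode to train RCF models on data that is formatted as recordIO-wrapped-protobuf or as CSV.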
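Following that description, a labeled test set in text/csv;label_size=1 form is simply the label followed by the feature columns on each line. Here is a minimal standard-library sketch, assuming no header row (the usual convention for CSV input to SageMaker built-in algorithms) and feature rows of length feature_dim; to_labeled_csv is a hypothetical helper, not part of any SDK, and the values are made up.

```python
import csv
import io


def to_labeled_csv(labels, features):
    """Serialize test records as `label,feature_1,...,feature_d` (label first, no header)."""
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    for label, row in zip(labels, features):
        if label not in (0, 1):
            raise ValueError("label must be 1 (anomaly) or 0 (normal)")
        writer.writerow([label, *row])
    return buf.getvalue()


# feature_dim=1, as in the question's hyperparameters; values are made up.
body = to_labeled_csv([0, 1, 0], [[12.5], [97.0], [11.8]])
print(body, end="")
# 0,12.5
# 1,97.0
# 0,11.8
```

The resulting file can then be uploaded to its own S3 prefix, separate from the training data, and used as the test channel.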

Also note that the test channel only supports S3DataDistributionType=FullyReplicated.
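In practice, this means the training job needs a second channel named test next to train. The sketch below only builds the InputDataConfig list in the shape expected by the low-level CreateTrainingJob API (the boto3 SageMaker client's create_training_job) and makes no AWS calls. The S3 URIs are hypothetical. It mirrors the question's train channel and adds a labeled CSV test channel that uses FullyReplicated distribution.

```python
def build_input_data_config(train_uri, test_uri):
    """Channel list shaped like the InputDataConfig argument of create_training_job."""
    return [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": train_uri,
                    # Training data may be sharded across instances, as in the question.
                    "S3DataDistributionType": "ShardedByS3Key",
                }
            },
            "ContentType": "application/x-recordio-protobuf",
        },
        {
            "ChannelName": "test",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": test_uri,
                    # The test channel only supports FullyReplicated.
                    "S3DataDistributionType": "FullyReplicated",
                }
            },
            # Labeled CSV: the first column of each row is the anomaly label.
            "ContentType": "text/csv;label_size=1",
        },
    ]


# Hypothetical S3 locations; replace with your own.
config = build_input_data_config("s3://my-bucket/rcf/train", "s3://my-bucket/rcf/test")
```

With the SageMaker Python SDK v1 used in the question, the same idea becomes a second sagemaker.session.s3_input(s3_test_data, distribution='FullyReplicated', content_type='text/csv;label_size=1') passed as rcf.fit({'train': s3_train_input, 'test': s3_test_input}), where s3_test_data is a hypothetical S3 URI for the labeled test data.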

Thanks

Julio