如何使用Lambda和API网关部署由AWS Sagemaker创建的乳腺癌预测终端节点?

时间:2018-10-01 09:43:17

标签: amazon-web-services aws-lambda aws-api-gateway amazon-sagemaker

我正在尝试使用aws lambda和API网关在Amazon sagemanker上部署现有的乳腺癌预测模型。我已经按照以下网址提供了官方文档。

https://aws.amazon.com/blogs/machine-learning/call-an-amazon-sagemaker-model-endpoint-using-amazon-api-gateway-and-aws-lambda/

我在“ predicted_label”处遇到类型错误。

    <div class="container-fluid bg">
        <div class="row">
            <div class="col-md-4 col-sm-6 col-xs-12">
            </div>
            <div class="col-md-4 col-sm-6 col-xs-12">
                <form class="form-container">
                    <div class="form-group">
                        <label for="exampleInputEmail1">Email address</label>
                        <input type="email" class="form-control" id="exampleInputEmail1" placeholder="Email">
                    </div>
                    <div class="form-group">
                        <label for="exampleInputPassword1">Password</label>
                        <input type="password" class="form-control" id="exampleInputPassword1" placeholder="Password">
                    </div>

                    <button type="submit" class="btn btn-default">Submit</button>
                </form> 
            </div>
            <div class="col-md-4 col-sm-6 col-xs-12">
            </div>
        </div>
    </div>
and css file is:

    .bg{
    background: url('bg.jpg');
    height: 100%;
    width: 100%;
    background-repeat: no-repeat;   
}
.form-container{
    border: 1px solid;
}

请让我知道是否有人可以解决此问题。谢谢。

1 个答案:

答案 0 :(得分:3)

通过用print(type(result))打印结果类型,您可以看到其字典。现在您可以看到键名是“ score”,而不是您为pred提供的“ predicted_label”。因此,将其替换为

pred = int(result['predictions'][0]['score'])

我认为这可以解决您的问题。

这是我的lambda函数:

import os
import io
import boto3
import json
import csv

# grab environment variables
ENDPOINT_NAME = os.environ['ENDPOINT_NAME']
runtime= boto3.client('runtime.sagemaker')

def lambda_handler(event, context):
   print("Received event: " + json.dumps(event, indent=2))

   data = json.loads(json.dumps(event))
   payload = data['data']
   print(payload)

   response = runtime.invoke_endpoint(EndpointName=ENDPOINT_NAME,
                                      ContentType='text/csv',
                                      Body=payload)
   #print(response)
   print(type(response))
   for key,value in response.items():
       print(key,value)
   result = json.loads(response['Body'].read().decode())
   print(type(result))
   print(result['predictions'])
   pred = int(result['predictions'][0]['score'])
   print(pred)
   predicted_label = 'M' if pred == 1 else 'B'

   return predicted_label