如何处理.csv输入以用于Tensorflow服务批量转换?

时间:2020-08-04 14:00:53

标签: csv amazon-s3 tensorflow-serving amazon-sagemaker

信息: 我正在从S3存储桶中加载现有的受训model.tar.gz,并希望使用包含输入数据的.csv执行批处理转换。 data.csv的结构使其可以将其读入pandas DataFrame中,从而获得完整的预测输入行。

笔记:
  • 这是使用Python SDK在Amazon Sagemaker上完成的
  • BATCH_TRANSFORM_INPUT是data.csv的路径。
  • 我能够将内容加载到model.tar.gz中,并使用它们在Tensorflow上在我的本地计算机上进行推断,日志显示2020-08-04 13:35:01.123557: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: model version: 1},因此该模型似乎已经过训练并正确保存了。
  • data.csv与训练数据的格式完全相同,这意味着每个“预测”一行,其中该行中的所有列代表不同的功能。
  • 将参数策略更改为“ MultiRecord”会出现相同的错误
  • [s3中的路径]代替了真实路径,因为我不想透露任何存储桶信息。
  • TensorFlow ModelServer:2.0.0 + dev.sha.ab786af
  • TensorFlow库:2.0.2

特征为1-5的文件data.csv如下:

+------+-------------------------+---------+----------+---------+----------+----------+
| UNIT | TS                      | 1       | 2        | 3       | 4        | 5        |
+------+-------------------------+---------+----------+---------+----------+----------+
| 110  | 2018-01-01 00:01:00.000 | 1.81766 | 0.178043 | 1.33607 | 25.42162 | 12.85445 |
+------+-------------------------+---------+----------+---------+----------+----------+
| 110  | 2018-01-01 00:02:00.000 | 1.81673 | 0.178168 | 1.30159 | 25.48204 | 12.87305 |
+------+-------------------------+---------+----------+---------+----------+----------+
| 110  | 2018-01-01 00:03:00.000 | 1.8155  | 0.176242 | 1.38399 | 25.35309 | 12.47222 |
+------+-------------------------+---------+----------+---------+----------+----------+
| 110  | 2018-01-01 00:04:00.000 | 1.81530 | 0.176398 | 1.39781 | 25.18216 | 12.16837 |
+------+-------------------------+---------+----------+---------+----------+----------+
| 110  | 2018-01-01 00:05:00.000 | 1.81505 | 0.151682 | 1.38451 | 25.22351 | 12.41623 |
+------+-------------------------+---------+----------+---------+----------+----------+

inference.py当前看起来像:

def input_handler(data, context):
    import pandas as pd
    if context.request_content_type == 'text/csv':
        payload = pd.read_csv(data)
        instance = [{"dataset": payload}]
        return json.dumps({"instances": instance})
    else:
        _return_error(416, 'Unsupported content type "{}"'.format(context.request_content_type or 'Unknown'))

问题:

当以下代码在我的jupyter Notebook中运行时:

sagemaker_model = Model(model_data = '[path in s3]/savedmodel/model.tar.gz'),  
                        sagemaker_session=sagemaker_session,
                        role = role,
                        framework_version='2.0',
                        entry_point = os.path.join('training', 'inference.py')
                        )

tf_serving_transformer = sagemaker_model.transformer(instance_count=1,
                                                     instance_type='ml.p2.xlarge',
                                                     max_payload=1,
                                                     output_path=BATCH_TRANSFORM_OUTPUT_DIR,
                                                     strategy='SingleRecord')


tf_serving_transformer.transform(data=BATCH_TRANSFORM_INPUT, data_type='S3Prefix', content_type='text/csv')
tf_serving_transformer.wait()

该模型似乎已加载,但最终出现以下错误: 2020-08-04T09:54:27.415:[sagemaker logs]: MaxConcurrentTransforms=1, MaxPayloadInMB=1, BatchStrategy=SINGLE_RECORD 2020-08-04T09:54:27.503:[sagemaker logs]: [path in s3]/data.csv: ClientError: 400 2020-08-04T09:54:27.503:[sagemaker logs]: [path in s3]/data.csv: 2020-08-04T09:54:27.503:[sagemaker logs]: [path in s3]/data.csv: Message: 2020-08-04T09:54:27.503:[sagemaker logs]: [path in s3]/data.csv: { "error": "Failed to process element: 0 of 'instances' list. Error: Invalid argument: JSON Value: \"\" Type: String is not of expected type: float" }

更清楚的错误:

ClientError:400 消息:{“错误”:“无法处理'instances'列表的元素:0。错误:无效的参数:JSON值:“”类型:字符串不是预期的类型:float“}

如果我正确理解此错误,则我的数据的结构方式出了点问题,因此sagemaker无法将输入数据传递到TFS模型。我想inference.py中缺少一些“输入处理”。也许csv数据必须以某种方式转换为兼容的JSON,以便TFS使用它?在input_handler()中到底需要做什么?

感谢所有帮助,对于这个令人困惑的案例,我们深表歉意。如果需要其他信息,请询问,我们将竭诚为您提供帮助。

1 个答案:

答案 0 :(得分:1)

解决方案:通过使用参数header = False,index = False将数据帧另存为.csv来解决了问题。这使保存的csv不包含数据帧索引标签。 TFS接受仅带有浮点值(无标签)的干净.csv。我假设错误消息 Invalid参数:JSON值:“”类型:字符串不是预期的类型:float 指的是csv中的第一个单元格,如果csv是使用标签导出的,则它只是一个空单元格。当它得到一个空字符串而不是一个浮点值时,就会感到困惑。