在Google Cloud Platform上使用AI平台/ ML引擎进行Scikit Learn的在线预测失败

时间:2019-05-14 05:21:40

标签: python machine-learning scikit-learn google-cloud-platform google-cloud-ml

我正在尝试使用scikit-learn框架在带有AI平台(ML-Engine)的GCP上的Python中运行在线预测。这些示例看起来非常简单明了,但是我陷入了似乎无法在任何地方解决的错误消息。

Error: RuntimeError: Prediction failed: Exception during sklearn prediction: unorderable types: NoneType() < int()

我正在使用一个非常干净的数据集,其中用于运行预测的数据是一个列表列表,每个元素都是浮点数且没有缺失值。使用Google SDK时,我也一直在努力正确设置JSON格式,因此,如果该错误与该错误有关,则可以一次解决两个问题。

import numpy as np
import json

predict_X = predict_features_scaled.tolist()
with open('json_file.json', 'w', encoding='utf-8') as f:
    for x in predict_X:
        json.dump(x, f)
        f.write("\n")

!gcloud ai-platform predict --model my_model \
    --version v1 \
    --json-instances json_file.json

这将产生:

RuntimeError: Prediction failed: Exception during sklearn prediction: unorderable types: NoneType() < int()

类似地:

import googleapiclient.discovery

data = predict_features_scaled.tolist()
instances_to_predict = data[0:4]

PROJECT_ID = 'my_project'
MODEL_NAME = 'my_model'
VERSION_NAME = 'v1'

service = googleapiclient.discovery.build('ml', 'v1')
name = 'projects/{}/models/{}/versions/{}'.format(PROJECT_ID, MODEL_NAME, VERSION_NAME)

response = service.projects().predict(
    name=name,
    body={'instances': instances_to_predict}
).execute()

if 'error' in response:
    raise RuntimeError(response['error'])
else:
  print(response['predictions'])

这将产生:

RuntimeError: Prediction failed: Exception during sklearn prediction: unorderable types: NoneType() < int()

模型已正确训练,当我以instances_to_predict的格式加载模型时不会造成任何问题。

local_model_path = os.path.join(local_path, 'model.pkl')
my_model = pickle.load(open(local_model_path, 'rb'))
my_model.predict(instances_to_predict)

输出:array([False, True, False, True])

instances_to_predict看起来像这样(浮点列表的列表):

[[0.3814189029296227, 0.2377409536692593, 0.3796558634510401, 0.23155885471898197, 0.4419029078963379, 0.3589350346604503, 0.18090440487347703, 0.30526838966202785, 0.26748531767218375, 0.39448188711036236, 0.09433279015028068, 0.10654614568599717, 0.07289261650096596, 0.05236851837324756, 0.08192541727572492, 0.11970138492504584, 0.050404040404040396, 0.19018753551809056, 0.03754150954015874, 0.08091842516203031, 0.3141230878690858, 0.22414712153518124, 0.2952836296628319, 0.1650855289028706, 0.2795350987254837, 0.18463971437164672, 0.14824281150159743, 0.34982817869415805, 0.12063867534003553, 0.1997245179063361], [0.32367835676085005, 0.4998309097057827, 0.3354294796489531, 0.19189819724284196, 0.65324248692055, 0.45616833323109013, 0.31794751640112456, 0.33593439363817096, 0.5915643352909771, 0.471988205560236, 0.13166757197175447, 0.2580887553041018, 0.10446214013098998, 0.060231827537644875, 0.2708297922969711, 0.27268903776248987, 0.08777777777777777, 0.3061185830649744, 0.2315810209939776, 0.21074996890676176, 0.28744219139096416, 0.5575692963752665, 0.2768564171522487, 0.1481517892253244, 0.7147196724559202, 0.3583064101444635, 0.27004792332268374, 0.522680412371134, 0.41119653065247386, 0.41492850583759683], [0.2630981116001703, 0.22353736895502196, 0.2588625526915901, 0.1411664899257688, 0.6483757148071543, 0.27335746273234773, 0.13889409559512647, 0.16297216699801192, 0.414308595835558, 0.34035383319292345, 0.052942241535397415, 0.1623939179632249, 0.03811902181595438, 0.021830488720540605, 0.27157765917666654, 0.10077507735752697, 0.06502525252525251, 0.21992801666982384, 0.2832217031575393, 0.05476555698354131, 0.20064034151547494, 0.25, 0.18058668260371535, 0.09014942980731418, 0.5734002509410288, 0.14340600168815673, 0.1508785942492013, 0.28903780068728524, 0.3134240094618569, 0.13341204250295152], [0.7624118510104596, 0.3422387554954345, 0.7484624421256305, 0.6525980911983033, 0.37607981506265986, 0.33163609594503407, 0.39268978444236174, 0.4985089463220676, 0.18045915643352917, 0.1027801179443979, 0.4271229404309251, 0.08508663366336634, 0.41831974744381095, 0.323867478025693, 0.1373355542713397, 0.235136840207889, 0.10040404040404038, 0.40841068384163665, 0.19879551978386895, 0.0691720907093404, 0.8132337246531484, 0.3158315565031984, 0.8022809900891479, 0.6382225717656311, 0.32179885095423627, 0.32522241949723973, 0.3030351437699681, 0.7780068728522336, 0.26473487088507786, 0.11629279811097992]]```

0 个答案:

没有答案