我已经训练了TensorFlow CNN模型来执行文本分类。培训完成并评估没有问题,但当我在GCP上托管模型时发送预测时,我收到以下错误,我发现很难理解:
RuntimeError: Prediction failed: Error during model execution: AbortionError(code=StatusCode.INVALID_ARGUMENT, details="NodeDef mentions attr 'output_type' not in Op<name=ArgMax; signature=input:T, dimension:Tidx -> output:int64; attr=T:type,allowed=[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_UINT16, DT_INT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT32, DT_HALF]; attr=Tidx:type,default=DT_INT32,allowed=[DT_INT32, DT_INT64]>; NodeDef: predicted = ArgMax[T=DT_FLOAT, Tidx=DT_INT32, _output_shapes=[[-1]], output_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](fc/logits, predicted/dimension)
[[Node: predicted = ArgMax[T=DT_FLOAT, Tidx=DT_INT32, _output_shapes=[[-1]], output_type=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](fc/logits, predicted/dimension)]]")
有谁能告诉我这里出了什么问题?
答案 0 :(得分:0)
此错误的原因很可能是您用于培训和导出的tensorflow与Cloud ML Engine运行时之间的版本不匹配。
要解决此问题,您必须使用version
CLI部署gcloud
(Web控制台不支持选择运行时版本),同时指定模型的运行时版本。
假设您在本地使用tensorflow 1.8并以此方式训练和导出模型,您可以像使用控制台一样将模型上传到存储桶并创建模型。然后,为了创建版本,请执行以下操作:
gcloud ml-engine versions create <version_name> \
--model=<model_name> \
--origin=<gs://bucket_name> \
--runtime-version=1.8
这解决了我的问题。