启用批处理时,带有TF服务的TFX“批量输出张量具有0个维”错误

时间:2020-03-25 18:58:40

标签: tensorflow tensorflow-serving

我正在使用TFX训练张量流模型,这是我的EvalInputReceiver的代码段

  serialized_tf_example = tf.placeholder(
      dtype=tf.string, shape=[None], name='input_example_tensor')

  features = tf.parse_example(serialized_tf_example, raw_feature_spec)

  transformed_features = tf_transform_output.transform_raw_features(features)

  receiver_tensors = {'examples': serialized_tf_example}

  features.update(transformed_features)

  return tfma.export.EvalInputReceiver(
      features=features,
      receiver_tensors=receiver_tensors,
      labels=transformed_features[_transformed_name(_LABEL_KEY)

训练和推送模型时,我是通过GRPC调用模型,如下所示:

    example = example_pb2.Example()
    examples = tf.make_tensor_proto([example.SerializeToString()], dtype=tf.string)
request = predict_pb2.PredictRequest()
    request.model_spec.name = "model"
    request.inputs["examples"].CopyFrom(examples)
    self.prediction_service.Predict(request)

当没有使用--enable-batching运行tensorflow-serving时,这很好用。但是,当这些参数启用时:

num_batch_threads { value: 32 }
max_batch_size { value: 128 }
batch_timeout_micros { value: 1000 }
max_enqueued_batches { value: 1000000 }

启用批处理后调用模型时,出现以下错误:“批处理输出张量的尺寸为0”。

如果我实际发送发送多个请求(例如使用

tf.make_tensor_proto([example.SerializeToString()]*2, dtype=tf.string)

有什么想法可以解决这个问题?

0 个答案:

没有答案