我使用tensorflow构建了多个DNN和conVNN,现在我可以达到很好的准确性。现在我的问题是如何在实际例子中使用这个训练有素的网络。 我是计算机视觉的一个例子,我如何使用该模型对新图片进行分类?我可以生成像convNN.exe那样的图像作为输入参数通过分类结果输出吗?
答案 0 :(得分:1)
一旦您对该模型进行了培训,您应该通过添加类似于
的代码将其保存在某处builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
signature_def_map={
'predict_images':
prediction_signature,
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
classification_signature,
},
legacy_init_op=legacy_init_op)
builder.save()
然后,您可以使用Tensorflow serving通过运行
使用高性能C ++服务器来为您的模型提供服务bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server \
--port=9000 --model_name=mnist \
--model_base_path=/tmp/mnist_model/
当然,修改模型的代码。您需要实施客户端;这是MNIST here的一个例子。客户的胆量如下:
def do_inference(hostport, work_dir, concurrency, num_tests):
"""Tests PredictionService with concurrent requests.
Args:
hostport: Host:port address of the PredictionService.
work_dir: The full path of working directory for test data set.
concurrency: Maximum number of concurrent requests.
num_tests: Number of test images to use.
Returns:
The classification error rate.
Raises:
IOError: An error occurred processing test data set.
"""
test_data_set = mnist_input_data.read_data_sets(work_dir).test
host, port = hostport.split(':')
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
result_counter = _ResultCounter(num_tests, concurrency)
for _ in range(num_tests):
request = predict_pb2.PredictRequest()
request.model_spec.name = 'mnist'
request.model_spec.signature_name = 'predict_images'
image, label = test_data_set.next_batch(1)
request.inputs['images'].CopyFrom(
tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size]))
result_counter.throttle()
result_future = stub.Predict.future(request, 5.0) # 5 seconds
result_future.add_done_callback(
_create_rpc_callback(label[0], result_counter))
return result_counter.get_error_rate()
def main(_):
if FLAGS.num_tests > 10000:
print('num_tests should not be greater than 10k')
return
if not FLAGS.server:
print('please specify server host:port')
return
error_rate = do_inference(FLAGS.server, FLAGS.work_dir,
FLAGS.concurrency, FLAGS.num_tests)
print('\nInference error rate: %s%%' % (error_rate * 100))
if __name__ == '__main__':
tf.app.run()
当然,这是用Python编写的,但是如果你想创建一个二进制可执行文件,你没有理由不能使用其他语言(例如Go或C ++)。