How to use TensorFlow

Time: 2017-07-24 12:39:51

Tags: tensorflow

I have built several DNNs and ConvNNs with TensorFlow and can now reach good accuracy. My question is how to use these trained networks in a real-world setting. Taking computer vision as an example: how do I use the model to classify a new picture? Can I produce something like a convNN.exe that takes an image as an input parameter and outputs the classification result?

1 answer:

Answer 0 (score: 1)

Once you have trained the model, you should save it somewhere by adding code along the lines of:
# Imports used by the SavedModel export code below (TensorFlow 1.x):
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

# sess, prediction_signature, classification_signature and legacy_init_op
# come from your model/export code; export_path is the directory to write to.
builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
      sess, [tag_constants.SERVING],
      signature_def_map={
           'predict_images':
               prediction_signature,
           signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
               classification_signature,
      },
      legacy_init_op=legacy_init_op)
builder.save()
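
For reference (not part of the original answer), the prediction_signature and classification_signature above are SignatureDef protos built from the model's input and output tensors. A minimal sketch following the TensorFlow 1.x MNIST export example's conventions, assuming x is the image placeholder and y the softmax output (these names are illustrative):

from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils as saved_model_utils

# Wrap the graph's input/output tensors in TensorInfo protos.
tensor_info_x = saved_model_utils.build_tensor_info(x)  # image input placeholder
tensor_info_y = saved_model_utils.build_tensor_info(y)  # softmax scores output

prediction_signature = signature_def_utils.build_signature_def(
    inputs={'images': tensor_info_x},
    outputs={'scores': tensor_info_y},
    method_name=signature_constants.PREDICT_METHOD_NAME)
# classification_signature is built the same way, using CLASSIFY_METHOD_NAME and the
# classify-specific input/output keys from signature_constants.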

Then you can serve the model with TensorFlow Serving, a high-performance C++ model server, by running:
bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server \
    --port=9000 --model_name=mnist \
    --model_base_path=/tmp/mnist_model/
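
Note that the model server watches --model_base_path for numbered version subdirectories (e.g. /tmp/mnist_model/1) and serves the highest version it finds, so the export_path used when saving might be constructed like this (directory names here simply mirror the command above):

import os

model_base_path = '/tmp/mnist_model'  # same directory passed as --model_base_path
model_version = 1                     # the server loads the highest-numbered version
export_path = os.path.join(model_base_path, str(model_version))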

(modifying the code for your model, of course). You will also need to implement a client; here is an example for MNIST. The guts of the client look like this:

# Imports from the surrounding mnist_client example (grpc beta API, TensorFlow 1.x);
# _ResultCounter and _create_rpc_callback are helpers defined elsewhere in that file.
from grpc.beta import implementations

import tensorflow as tf

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

import mnist_input_data


def do_inference(hostport, work_dir, concurrency, num_tests):
  """Tests PredictionService with concurrent requests.
  Args:
    hostport: Host:port address of the PredictionService.
    work_dir: The full path of working directory for test data set.
    concurrency: Maximum number of concurrent requests.
    num_tests: Number of test images to use.
  Returns:
    The classification error rate.
  Raises:
    IOError: An error occurred processing test data set.
  """
  test_data_set = mnist_input_data.read_data_sets(work_dir).test
  host, port = hostport.split(':')
  channel = implementations.insecure_channel(host, int(port))
  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
  result_counter = _ResultCounter(num_tests, concurrency)
  for _ in range(num_tests):
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'mnist'
    request.model_spec.signature_name = 'predict_images'
    image, label = test_data_set.next_batch(1)
    request.inputs['images'].CopyFrom(
        tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size]))
    result_counter.throttle()
    result_future = stub.Predict.future(request, 5.0)  # 5 seconds
    result_future.add_done_callback(
        _create_rpc_callback(label[0], result_counter))
  return result_counter.get_error_rate()


def main(_):
  if FLAGS.num_tests > 10000:
    print('num_tests should not be greater than 10k')
    return
  if not FLAGS.server:
    print('please specify server host:port')
    return
  error_rate = do_inference(FLAGS.server, FLAGS.work_dir,
                            FLAGS.concurrency, FLAGS.num_tests)
  print('\nInference error rate: %s%%' % (error_rate * 100))

if __name__ == '__main__':
  tf.app.run()
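
The FLAGS referenced in main are plain tf.app.flags definitions declared at the top of the example; a sketch of what they might look like (defaults here are illustrative):

tf.app.flags.DEFINE_integer('concurrency', 1,
                            'maximum number of concurrent inference requests')
tf.app.flags.DEFINE_integer('num_tests', 100, 'number of test images to use')
tf.app.flags.DEFINE_string('server', '', 'PredictionService host:port')
tf.app.flags.DEFINE_string('work_dir', '/tmp', 'working directory for the test data')
FLAGS = tf.app.flags.FLAGS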

Of course, this is written in Python, but there is no reason you couldn't use another language (e.g. Go or C++) if you want to create a binary executable.
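
If you do want the convNN.exe-style workflow from the question (image in, class out), the same gRPC client can be trimmed down to a small single-image script. A minimal sketch, assuming the server above is running and the model expects flattened 28x28 grayscale floats under the 'images' input and returns 'scores' (these names follow the MNIST example and are assumptions, not part of the original answer):

import sys

import numpy as np
import tensorflow as tf
from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2


def classify(image_path, hostport='localhost:9000'):
  """Sends a single image to the model server and returns the predicted class."""
  # Decode and flatten the image with TensorFlow itself to avoid extra dependencies.
  with tf.Session() as sess:
    raw = tf.read_file(image_path)
    image = tf.image.decode_png(raw, channels=1)
    image = tf.image.resize_images(image, [28, 28])
    pixels = sess.run(tf.reshape(image / 255.0, [784]))

  host, port = hostport.split(':')
  channel = implementations.insecure_channel(host, int(port))
  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

  request = predict_pb2.PredictRequest()
  request.model_spec.name = 'mnist'
  request.model_spec.signature_name = 'predict_images'
  request.inputs['images'].CopyFrom(
      tf.contrib.util.make_tensor_proto(pixels, shape=[1, pixels.size]))

  result = stub.Predict(request, 5.0)  # blocking call, 5 second timeout
  scores = np.array(result.outputs['scores'].float_val)
  return int(np.argmax(scores))


if __name__ == '__main__':
  print(classify(sys.argv[1]))

This is still Python behind a script entry point; the note above about Go or C++ applies if you really need a self-contained binary.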