我应该如何在TensorFlow SavedModel中存储元数据?

时间:2019-01-19 20:28:11

标签: tensorflow tensorflow-serving

我们使用不同的配置训练了模型的许多变化,并要求对输入进行不同的预处理(其中,预处理在TensorFlow之外进行)。我想将我们的模型导出为SavedModels,并且我想我们将拥有一个API服务器,该服务器将提供对模型的访问并处理预处理并使用config与TensorFlow服务器进行对话,该配置将通过TensorFlow从模型元数据中检索服务器。模型元数据可能被构造为JSON,或者可能使用协议缓冲区。我不清楚这方面有哪些最佳做法。特别是,MetaInfoDef协议缓冲区具有三个不同的字段,这些字段似乎设计用于保存元数据(meta_graph_versionany_infotags)。但是除了tags字段之外,我找不到任何示例。

// User specified Version string. Can be the name of the model and revision,
// steps this model has been trained to, etc.
string meta_graph_version = 1;

[...]

// A serialized protobuf. Can be the time this meta graph is created, or
// modified, or name of the model.
google.protobuf.Any any_info = 3;

// User supplied tag(s) on the meta_graph and included graph_def.
//
// MetaGraphDefs should be tagged with their capabilities or use-cases.
// Examples: "train", "serve", "gpu", "tpu", etc.
// These tags enable loaders to access the MetaGraph(s) appropriate for a
// specific use-case or runtime environment.
repeated string tags = 4;

(尽管我不确定是否可以使用TensorFlow服务的客户端API以相同的方式检索这三个字段?)

2 个答案:

答案 0 :(得分:0)

使用客户端API(REST)提取元数据的命令如下所示

获取 http://host:port/v1/models/ $ {MODEL_NAME} [/ versions / $ {MODEL_VERSION}] /元数据

/ versions / $ {MODEL_VERSION}是可选的。如果省略,则会在响应中返回最新版本的模型元数据。

您可以在链接https://www.tensorflow.org/tfx/serving/api_rest/ =>模型元数据API中找到更多详细信息

答案 1 :(得分:0)

@gmr, 通过 tf.add_to_collection 将原型添加到集合中,以及 builder.add_meta_graph_and_variables 可以解决您的问题。

相同的代码如下:

# Mention the path below where you want the model to be stored
export_dir = "/usr/local/google/home/abc/Jupyter_Notebooks/export"

tf.gfile.DeleteRecursively(export_dir)

tf.reset_default_graph()

# Check below for other ways of adding Proto to Collection
tf.add_to_collection("my_proto_collection", "my_proto_serialized")

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session() as session:
  builder.add_meta_graph_and_variables(
      session,
      tags=[tf.saved_model.tag_constants.SERVING])
  builder.save()

其他将原型添加到集合中的方法的代码如下所示:

tf.add_to_collection("your_collection_name", str(your_proto))

any_buf = any_pb2.Any()

tf.add_to_collection("your_collection_name",
         any_buf.Pack(your_proto))

.pb文件save_model.pb保存在您提到的路径(export_dir)中,如下所示:

{   # (tensorflow.SavedModel) size=89B
  saved_model_schema_version: 1
  meta_graphs: {    # (tensorflow.MetaGraphDef) size=85B
    meta_info_def: {    # (tensorflow.MetaGraphDef.MetaInfoDef) size=29B
      stripped_op_list: {   # (tensorflow.OpList) size=0B
      } # meta_graphs[0].meta_info_def.stripped_op_list
      tags    : [ "serve" ] # size=5
      tensorflow_version    : "1.13.1"  # size=9
      tensorflow_git_version: "unknown" # size=7
    }   # meta_graphs[0].meta_info_def
    graph_def: {    # (tensorflow.GraphDef) size=4B
      versions: {   # (tensorflow.VersionDef) size=2B
        producer     : 23
      } # meta_graphs[0].graph_def.versions
    }   # meta_graphs[0].graph_def
    collection_def: {   # (tensorflow.MetaGraphDef.CollectionDefEntry) size=46B
      key  : "my_proto_collection"  # size=19
      value: {  # (tensorflow.CollectionDef) size=23B
        bytes_list: {   # (tensorflow.CollectionDef.BytesList) size=21B
          value: [ "my_proto_serialized" ]  # size=19
        }   # meta_graphs[0].collection_def[0].value.bytes_list
      } # meta_graphs[0].collection_def[0].value
    }   # meta_graphs[0].collection_def[0]
  } # meta_graphs[0]
}