Tensorflow集线器标签和导出

时间:2018-06-06 01:00:28

标签: python tensorflow tensorflow-serving tensorflow-estimator tensorflow-hub

我很困惑标签应该如何在集线器中工作,以及如何在导出时使用它们。如何在我的图表的火车部分上训练并导出服务?

我有以下代码:

def user_module_fn(foo, bar):
    x = tf.sparse_placeholder(tf.float32, shape[-1, 32], name='name')
    y = something(x)
    hub.add_signature(name='my_name', input={"x": x}, output={"default", y})

module_spec = hub.create_module_spec(module_spec_fn, tags_and_args=[
   (set(), {"foo": foo, "bar": bar}),
   ({"train"}, {"foo": foo, "bar": baz})
 ])

m = hub.Module(module_spec, name="my_name", trainable=True, tags={"train"})

hub.register_for_export(m, "my_name")

我的问题如下:因为我正在将模块m实例化为tags={'train'},我认为我正在使用正确的模块进行培训。这是否意味着我导出标有train的标签?如何使用train进行培训,set()(默认情况下)进行投放?

1 个答案:

答案 0 :(得分:1)

在最好(即最简单)的情况下,您的模块根本不需要任何标签,即当同一条TensorFlow图适合模块的所有预期用途时。为此,只需离开tagstags_and_args即可获取默认值(一组空标记)。

如果同一模块需要多个版本的图表,例如,在训练模式中应用辍学的培训版本,以及使辍学成为无操作的推理版本,则需要标签。您通常会看到像

这样的代码
def module_fn(training):
  inputs = tf.placeholder(dtype=tf.float32, shape=[None, 50])
  layer1 = tf.layers.fully_connected(inputs, 200)
  layer1 = tf.layers.dropout(layer1, rate=0.5, training=training)
  layer2 = tf.layers.fully_connected(layer1, 100)
  outputs = dict(default=layer2)
  hub.add_signature(inputs=inputs, outputs=outputs)

...

tags_and_args = [(set(), {"training": False}),
                 ({"train"}, {"training": True})]
module_spec = hub.create_module_spec(module_fn, tags_and_args)

创建模块规范为所有提供的参数dicts运行module_fn,并存储 all 在幕后构建它们的图形。当您从该规范制作模块然后将其导出时,它将包含已创建的所有图形版本,并使用相应的字符串集进行标记。

tags=...的{​​{1}}参数仅控制在当前图形中使用哪个不同的图形版本,例如,调用m = hub.Module(...)时(即应用于输入)。它不会限制m写出的内容。