我已经在Python中训练了TensorFlow模型,并希望在Java代码中使用它。训练模型是通过以下代码完成的:
def input_fn():
features = {'a': tf.constant([[1],[2]]),
'b': tf.constant([[3],[4]]) }
labels = tf.constant([0, 1])
return features, labels
feature_a = tf.contrib.layers.sparse_column_with_integerized_feature("a", bucket_size=10)
feature_b = tf.contrib.layers.sparse_column_with_integerized_feature("b", bucket_size=10)
feature_columns = [feature_a, feature_b]
model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns)
model.fit(input_fn=input_fn, steps=10)
现在我想保存此模型以在Java中使用它。似乎export_savedmodel是新的/首选的保存方式,所以我尝试了:
feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(feature_columns)
serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
model.export_savedmodel('export', serving_input_fn, as_text=True)
这会生成一个已保存的模型,可以使用
从Java加载model = SavedModelBundle.load(dir, "serve");
model.session().runner()
.feed("input_example_tensor", input)
.fetch("linear/binary_logistic_head/predictions/probabilities")
.run();
现在有一个问题:input_example_tensor应该是包含Strings / byte []的Tensor,但是Java中还不支持它(参见:Tensor.java#88"抛出新的UnsupportedOperationException" )。据我所知,它想要一个String的原因是build_parsing_serving_input_fn想要解析序列化的示例协议缓冲区。
也许不同的serving_input_fn会做得更好。 input_fn_utils.build_default_serving_input_fn
看起来很有希望,但我没有让它发挥作用。
如果我称之为:
features_dict = {'a':feature_a, 'b':feature_b}
serving_input_fn = input_fn_utils.build_default_serving_input_fn(features)
我得到" AttributeError:' _SparseColumnIntegerized'对象没有属性' get_shape'"
如果我称之为:
features = {'a': tf.constant([[1],[2]]),
'b': tf.constant([[3],[4]]) }
serving_input_fn = input_fn_utils.build_default_serving_input_fn(features)
我得到" ValueError:' Const:0'不是有效的范围名称"。
使用input_fn_utils.build_default_serving_input_fn
的正确方法是什么?我找不到任何使用它的例子。