在Java

时间:2017-05-11 14:54:38

标签: java tensorflow

我已经在Python中训练了TensorFlow模型,并希望在Java代码中使用它。训练模型是通过以下代码完成的:

def input_fn():
    features = {'a': tf.constant([[1],[2]]),
                'b': tf.constant([[3],[4]]) }
    labels = tf.constant([0, 1])
    return features, labels

feature_a = tf.contrib.layers.sparse_column_with_integerized_feature("a", bucket_size=10)
feature_b = tf.contrib.layers.sparse_column_with_integerized_feature("b", bucket_size=10)
feature_columns = [feature_a, feature_b]

model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns)
model.fit(input_fn=input_fn, steps=10)

现在我想保存此模型以在Java中使用它。似乎export_savedmodel是新的/首选的保存方式,所以我尝试了:

feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(feature_columns)
serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
model.export_savedmodel('export', serving_input_fn, as_text=True)

这会生成一个已保存的模型,可以使用

从Java加载
model = SavedModelBundle.load(dir, "serve");
model.session().runner()
    .feed("input_example_tensor", input)
    .fetch("linear/binary_logistic_head/predictions/probabilities")
    .run();

现在有一个问题:input_example_tensor应该是包含Strings / byte []的Tensor,但是Java中还不支持它(参见:Tensor.java#88"抛出新的UnsupportedOperationException" )。据我所知,它想要一个String的原因是build_parsing_serving_input_fn想要解析序列化的示例协议缓冲区。

也许不同的serving_input_fn会做得更好。 input_fn_utils.build_default_serving_input_fn看起来很有希望,但我没有让它发挥作用。

如果我称之为:

features_dict = {'a':feature_a, 'b':feature_b}
serving_input_fn = input_fn_utils.build_default_serving_input_fn(features)

我得到" AttributeError:' _SparseColumnIntegerized'对象没有属性' get_shape'"

如果我称之为:

features = {'a': tf.constant([[1],[2]]),
            'b': tf.constant([[3],[4]]) }
serving_input_fn = input_fn_utils.build_default_serving_input_fn(features)

我得到" ValueError:' Const:0'不是有效的范围名称"。

使用input_fn_utils.build_default_serving_input_fn的正确方法是什么?我找不到任何使用它的例子。

0 个答案:

没有答案