永久更新tensorflow-java中的变量(在推理期间)

时间:2018-04-12 16:30:37

标签: java python variables tensorflow

我已经使用python-tensorflow训练了一个模型,我想在java-tensorflow中进行推理。我已将训练好的模型/图表加载到Java中。在此之后,我想永久更新图表中的一个变量。我知道python中的tf.variable.load(value,session)函数可以用来更新变量的值。我想知道Java中是否有类似的方法。

到目前为止,我已尝试过以下内容。

// g and s are loaded graphs and sessions respectively
s.runner().feed(variableName,updatedTensorValue)

但上述行仅在同一行执行的updatedTensorValue次调用期间variableName使用fetch

g.opBuilder("Assign",variableName).setAttr("value",updatedTensorValue).build();

上述代码不是更新值,而是尝试将相同的变量添加到图表中,因此它会抛出异常。

永久更新图表中变量的另一种方法是,我会在所有feed(variableName,updatedTensorValue)次调用期间始终调用fetch方法。我会在几个实例上运行推理代码,所以我想知道这个额外的feed调用所需的额外时间。

由于

1 个答案:

答案 0 :(得分:2)

在TensorFlow中执行大多数操作的方法是执行操作。您在尝试运行Assign操作时走在正确的轨道上,但却错误地调用了它,因为要分配的value不是"属性" Assign操作,而是输入张量。 (请参阅原始definition of the operation,但不可否认,除非您熟悉TensorFlow内部,否则定义可能并不容易理解。

但是,您不需要在Java中向图形添加操作来执行此操作。相反,你可以完全按照Python中的tf.Variable.load执行 - 执行tf.Variable.initializer操作,输入输入值。

例如,请考虑使用Python构建的以下图表:

import tensorflow as tf

var = tf.Variable(1.0, name='myvar')
init = tf.global_variables_initializer()

# Save the graph and write out the names of the operations of interest
tf.train.write_graph(tf.get_default_graph(), '/tmp', 'graph.pb', as_text=False)
print('Init all variables:         ', init.name)
print('myvar.initializer:          ', var.initializer.name)
print('myvar.initializer.inputs[1]:', var.initializer.inputs[1].name)

现在,我们在Java中复制Python var.load()的行为,使用类似的东西为变量赋值3.0:

try (Tensor<Float> newValue = Tensors.create(3.0f)) {
  s.runner()
    .feed("myvar/initial_value", newVal) // myvar.initializer.inputs[1].name
    .addTarget("myvar/Assign")           // myvar.initializer.name
    .run();
}

希望有所帮助。