我已经使用python-tensorflow训练了一个模型,我想在java-tensorflow中进行推理。我已将训练好的模型/图表加载到Java中。在此之后,我想永久更新图表中的一个变量。我知道python中的tf.variable.load(value,session)
函数可以用来更新变量的值。我想知道Java中是否有类似的方法。
到目前为止,我已尝试过以下内容。
// g and s are loaded graphs and sessions respectively
s.runner().feed(variableName,updatedTensorValue)
但上述行仅在同一行执行的updatedTensorValue
次调用期间variableName
使用fetch
。
g.opBuilder("Assign",variableName).setAttr("value",updatedTensorValue).build();
上述代码不是更新值,而是尝试将相同的变量添加到图表中,因此它会抛出异常。
永久更新图表中变量的另一种方法是,我会在所有feed(variableName,updatedTensorValue)
次调用期间始终调用fetch
方法。我会在几个实例上运行推理代码,所以我想知道这个额外的feed
调用所需的额外时间。
由于
答案 0 :(得分:2)
在TensorFlow中执行大多数操作的方法是执行操作。您在尝试运行Assign
操作时走在正确的轨道上,但却错误地调用了它,因为要分配的value
不是"属性" Assign
操作,而是输入张量。 (请参阅原始definition of the operation,但不可否认,除非您熟悉TensorFlow内部,否则定义可能并不容易理解。
但是,您不需要在Java中向图形添加操作来执行此操作。相反,你可以完全按照Python中的tf.Variable.load
执行 - 执行tf.Variable.initializer
操作,输入输入值。
例如,请考虑使用Python构建的以下图表:
import tensorflow as tf
var = tf.Variable(1.0, name='myvar')
init = tf.global_variables_initializer()
# Save the graph and write out the names of the operations of interest
tf.train.write_graph(tf.get_default_graph(), '/tmp', 'graph.pb', as_text=False)
print('Init all variables: ', init.name)
print('myvar.initializer: ', var.initializer.name)
print('myvar.initializer.inputs[1]:', var.initializer.inputs[1].name)
现在,我们在Java中复制Python var.load()
的行为,使用类似的东西为变量赋值3.0:
try (Tensor<Float> newValue = Tensors.create(3.0f)) {
s.runner()
.feed("myvar/initial_value", newVal) // myvar.initializer.inputs[1].name
.addTarget("myvar/Assign") // myvar.initializer.name
.run();
}
希望有所帮助。