如何使用Tensorflow 1.0 Java API创建/初始化变量

时间:2017-03-15 15:24:21

标签: java python tensorflow

我试图移植这行Python代码:

my_var = tf.Variable(3, name="input_a")

到Java。我能够以这种方式使用tf.constant执行此操作:

graph.opBuilder("Const", name)
        .setAttr("dtype", tensorVal.dataType())
        .setAttr("value", tensorVal).build()
        .output(0);

我尝试了与变量类似的方法:

graph.opBuilder("Variable", name)
        .setAttr("dtype", tensorVal.dataType())
        .setAttr("shape", shape)
        .build()
        .output(0);

但是我收到了这个错误:

Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value input_a
[[Node: input_a/_2 = _Send[T=DT_INT32, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_5_input_a", _device="/job:localhost/replica:0/task:0/cpu:0"](input_a)]]

我想我需要使用值设置一个特殊属性,或者我需要稍后对其进行初始化。但我找不到方法。

我计划对大多数其他tf方法(here我当前的努力)做同样的事情。所以我想了解如何自己提出答案。例如,通过查看此Python源代码:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/variable_scope.py https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/state_ops.py

我怀疑我应该分配"初始化程序"属性,但java API或初始化方法中没有Initializer接口。尚未实施?我是tensorflow和Python的新手。

1 个答案:

答案 0 :(得分:2)

与您有同样的需求,我使用张量流的assign节点将值赋给我的变量。首先,您需要按照您的方式定义节点,然后需要使用相应的值添加此节点。然后我在我的图表中稍后引用这个新分配的节点,这样就不会引发错误java.lang.IllegalStateException: Attempting to use uninitialized value

我使用GraphBuilder类扩展了Graph功能,并添加了这些必需的类:

class GraphBuilder(g: Graph ) {
  def variable(name: String, dataType: DataType, shape: Shape): Output = {
    g.opBuilder("Variable", name)
      .setAttr("dtype", dataType)
      .setAttr("shape", shape)
      .build()
      .output(0)
  }

  def assign(value: Output, variable: Output): Output = {
      graph.opBuilder("Assign", "Assign/" + variable.op().name()).addInput(variable).addInput(value).build().output(0)
  }
}

val WValue = Array.fill(numFeatures)(Array.fill(hiddenDim)(0.0))
val W = builder.variable("W", DataType.DOUBLE, Shape.make(numFeatures, hiddenDim))
val W_init = builder.assign(builder.constant("Wval", WValue), W)

assign 节点将在每个前向传递中为您的变量分配预设值,因此它也不适合训练。但无论如何,从这篇文章看来,您似乎需要添加依赖项,因为默认情况下JAVA API不提供训练节点:https://github.com/tensorflow/tensorflow/issues/5518