如何通过TensorFlowInferenceInterface.java提供布尔占位符?

时间:2017-03-24 14:03:21

标签: java android tensorflow keras

我尝试通过 Java Tensorflow API 启动 Keras Tensorflow 图表的培训。
除了标准输入图像占位符之外,此图还包含' keras_learning_phase' 占位符,需要提供布尔值值。

问题是,TensorFlowInferenceInterface中没有布尔值值的方法 - 你只能用 float double , int byte 值。

显然,当我尝试通过此代码将 int 传递给此张量时:

inferenceInterface.fillNodeInt("keras_learning_phase",  
                               new int[]{1}, new int[]{0});

我得到了

  

tensorflow_inference_jni.cc:207推理期间出错:内部:   int32类型的输出0与声明的输出类型bool不匹配   node _recv_keras_learning_phase_0 = _Recvclient_terminated = true,   recv_device =" /作业:本地主机/复制:0 /任务:0 / CPU:0&#34 ;,   send_device =" /作业:本地主机/复制:0 /任务:0 / CPU:0&#34 ;,   send_device_incarnation = 4742451733276497694,   tensor_name =" keras_learning_phase",tensor_type = DT_BOOL,   _device =" /作业:本地主机/复制:0 /任务:0 / CPU:0"

有没有办法绕过它?
也许有可能以某种方式将图中的占位符节点明确转换为常量
或者也许最初可以避免在图表中创建占位符

2 个答案:

答案 0 :(得分:5)

TensorFlowInferenceInterface类本质上是完整TensorFlow Java API的便利包装器,它支持布尔值。

你可以向TensorFlowInferenceInterface添加一个方法来做你想做的事。与fillNodeInt类似,您可以添加以下内容(请注意,TensorFlow中的布尔值表示为一个字节):

public void fillNodeBool(String inputName, int[] dims, bool[] src) {
  byte[] b = new byte[src.length];
  for (int i = 0; i < src.length; ++i) {
    b[i] = src[i] ? 1 : 0;
  }
  addFeed(inputName, Tensor.create(DatType.BOOL, mkDims(dims), ByteBuffer.wrap(b)));
}

希望有所帮助。如果有效,我建议您回馈TensorFlow代码库。

答案 1 :(得分:0)

这是answer之后ash的补充,因为Tensorflow API已经发生了一些变化。用这个对我有用:

public void feed(String inputName, boolean[] src, long... dims) {
  byte[] b = new byte[src.length];
  for (int i = 0; i < src.length; i++) {
    b[i] = src[i] ? (byte) 1 : (byte) 0;
  }
  addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
}