冻结的CNN模型的Java Tensorflow推理问题

时间:2019-02-19 23:06:57

标签: java python tensorflow keras deep-learning

我在使用Java Tensorflow API时遇到一些问题。

基本上,我试图使用我在Python中训练的冻结模型来预测一些图像,但是我想对Java的Tensorflow进行这些推断,以便以后可以开发的某些应用程序使用。

我首先将Python模型导出为.pb文件,然后将其加载到Tensorflow中,并且可以将其用于推断目的,我在Python中对其进行了测试,并且没有任何问题。

然后,我尝试修改Java Tensorflow示例中提供的LabelImage.java示例,该示例可在GitHub(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java)上找到。我基本上修改了模型的路径以及将要使用的图像。在成功纠正了一些错误之后,该代码是可运行的,但是我遇到了以下错误:

Exception in thread "main" java.lang.UnsupportedOperationException: Generic conv implementation does not support grouped convolutions for now.
 [[{{node conv2d_1/convolution}} = Conv2D[T=DT_FLOAT, data_format="NHWC", dilations=[1, 1, 1, 1], padding="SAME", strides=[1, 1, 1, 1], use_cudnn_on_gpu=true, _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_1_0_0, conv2d_1/kernel)]]

总体来说,我在Java和Tensorflow中还是一个新手,我试图找到类似的错误,例如我遇到的错误,但没有发现任何有用的东西。我想知道错误是否试图告诉我当前的Java Tensorflow API不支持卷积,如果是这种情况,我会感到非常惊讶。无论如何,我对解决该问题的方法感到迷茫,希望有人可以帮助我找出解决办法。

一些细节:我在Keras上构建并训练了U-Net模型,并使用Gi​​tHub上某个用户的方法将训练后的Keras模型转换为.pb文件,该文件可以重新加载到Tensorflow上。并进行推断(用户:https://github.com/amir-abdi/keras_to_tensorflow)。此重新加载和推断部分在Python中可以完美运行(我已经确定对其进行了测试)。

此代码块中似乎发生了错误:

 private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
try (Graph g = new Graph()) {
  g.importGraphDef(graphDef);
  try (Session s = new Session(g);
      // Generally, there may be multiple output tensors, all of them must be closed to prevent resource leaks.
      Tensor<Float> result =
          s.runner().feed("input_1", image).fetch("conv2d_24/Sigmoid").run().get(0).expect(Float.class)) {
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
      throw new RuntimeException(
          String.format(
              "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
              Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    return result.copyTo(new float[1][nlabels])[0];
  }
}

此代码未更改,因为我说过我只是更改了指向模型和用于测试的示例图像的输入路径。我更改的确切部分可以在下面找到:

  public static void main(String[] args) throws Exception {
System.out.println("TensorFlow version: " + TensorFlow.version());

byte[] graphDef = readAllBytesOrExit(Paths.get("C:\\Users\\joao_\\Documents\\GitHub\\Tensorflow-to-PB\\java_code\\src\\main\\resources\\test.pb"));
byte[] imageBytes = readAllBytesOrExit(Paths.get("C:\\Users\\joao_\\Documents\\GitHub\\Tensorflow-to-PB\\java_code\\src\\main\\resources\\02.png"));

try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
  float[] labelProbabilities = executeInceptionGraph(graphDef, image);
  int bestLabelIdx = maxIndex(labelProbabilities);
}

我希望这些信息足以理解问题。

1 个答案:

答案 0 :(得分:0)

好吧,最后我找到了自己问题的答案。

基本上,该错误与以下事实有关:我将图像馈送到没有适当大小的模型(我的图像为512x512,而我的模型仅获取256x256图像)。因此,我想问题是输入张量的尺寸不正确。

希望这篇文章对帮助有同样问题的人仍然有用。