java.lang.IllegalArgumentException: Cannot assign a device for operation

Posted: 2020-04-06 12:07:33

Tags: tensorflow

I used the following code to convert a pretrained TensorFlow model into a .pb file:

    import tensorflow as tf
    from argparse import ArgumentParser

    def main():
        parser = ArgumentParser()
        parser.add_argument('--checkpoint', type=str,
                            dest='checkpoint',
                            help='dir or .ckpt file to load checkpoint from',
                            metavar='CHECKPOINT', required=True)
        parser.add_argument('--model', type=str,
                            dest='model',
                            help='.meta for your model',
                            metavar='MODEL', required=True)
        parser.add_argument('--out-path', type=str,
                            dest='out_path',
                            help='model output directory',
                            metavar='MODEL_OUT', required=True)
        opts = parser.parse_args()
        tf.reset_default_graph()
        saver = tf.train.import_meta_graph(opts.model)
        #builder = tf.saved_model.builder.SavedModelBuilder(opts.out_path)
        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
            # Restore variables from disk.
            saver.restore(sess, opts.checkpoint)
            print("Model restored.")
            #builder.add_meta_graph_and_variables(sess,['tfckpt2pb'],strip_default_attrs=False)
            #builder.save()
            constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['Tower_0/parsing_fc/BiasAdd','Tower_0/parsing_rf_fc/BiasAdd','Tower_0/edge_rf_fc/BiasAdd'])

            with tf.gfile.FastGFile(opts.out_path, mode='wb') as f:
                f.write(constant_graph.SerializeToString())
                print("pb Model saved.")

    if __name__ == '__main__':
        main()     

Then I import the .pb model with the following Java code:

    // import model
    byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(MODEL_PATH));

    graph.importGraphDef(graphBytes);

    // create session
    try (Session session = new Session(graph)) {

        ConfigProto config = ConfigProto.newBuilder()
                .setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true))
                .build();
        // get the output
        Tensor<?> output = session.runner()
                .setOptions(config.toByteArray())
                .feed("Tower_0/strided_slice", imageTensor)
                .fetch("Tower_0/parsing_fc/BiasAdd").run().get(0);
        System.out.println(output);
    }

Running it produces the error below:

    2020-04-06 12:13:50.269556: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA
    2020-04-06 12:13:50.281419: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2600000000 Hz
    2020-04-06 12:13:50.288167: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f968ded1db0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
    2020-04-06 12:13:50.288206: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
    java.lang.IllegalArgumentException: Cannot assign a device for operation Tower_0/strided_slice: {{node Tower_0/strided_slice}} was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:XLA_CPU:0 ]. Make sure the device specification refers to a valid device. The requested device appears to be a GPU, but CUDA is not enabled.
         [[Tower_0/strided_slice]]
        at org.tensorflow.Session.run(Native Method)
        at org.tensorflow.Session.access$100(Session.java:48)
        at org.tensorflow.Session$Runner.runHelper(Session.java:326)
        at org.tensorflow.Session$Runner.run(Session.java:276)
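The message says the frozen graph still pins Tower_0/strided_slice to /device:GPU:0, while the runtime that was loaded registered only CPU:0 and XLA_CPU:0. The sketch below is a toy model, not TensorFlow's placer; the class SoftPlacementDemo and its place method are hypothetical names. It only mimics the allow_soft_placement rule described in TensorFlow's ConfigProto (an op may fall back to CPU when no GPU device is registered) to show why the same graph either runs on CPU or fails, depending on that flag.

```java
import java.util.List;

/** Toy model of device placement; NOT TensorFlow's real placer (hypothetical names). */
public class SoftPlacementDemo {

    /**
     * Returns the device an op pinned to requestedDevice would run on,
     * given the devices the runtime registered. Throws if it cannot be placed.
     */
    static String place(String requestedDevice, List<String> availableDevices,
                        boolean allowSoftPlacement) {
        // A pinned device that actually exists is used as-is.
        for (String device : availableDevices) {
            if (device.endsWith(requestedDevice)) {
                return device;
            }
        }
        // Simplified soft placement: fall back to the first CPU:0 device.
        if (allowSoftPlacement) {
            for (String device : availableDevices) {
                if (device.endsWith("/device:CPU:0")) {
                    return device;
                }
            }
        }
        throw new IllegalArgumentException("cannot place op on " + requestedDevice
                + "; available devices: " + availableDevices);
    }

    public static void main(String[] args) {
        // Device list taken from the error message above.
        List<String> available = List.of(
                "/job:localhost/replica:0/task:0/device:CPU:0",
                "/job:localhost/replica:0/task:0/device:XLA_CPU:0");

        // With soft placement the GPU-pinned op falls back to CPU:0.
        System.out.println(place("/device:GPU:0", available, true));

        // Without soft placement, placement fails, as in the question.
        try {
            place("/device:GPU:0", available, false);
        } catch (IllegalArgumentException e) {
            System.out.println("placement failed: " + e.getMessage());
        }
    }
}
```

For context: in the TensorFlow Java 1.x API, a session ConfigProto is passed as bytes to the Session(Graph, byte[]) constructor, whereas Session.Runner.setOptions expects a serialized RunOptions. Likewise, the allow_soft_placement=True set in the Python script configures only that Python session; it is not stored in the .pb file.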

My pom configuration:

    <!-- Maven dependencies -->
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow</artifactId>
        <version>1.15.0</version>
    </dependency>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>libtensorflow_jni_gpu</artifactId>
        <version>1.15.0</version>
    </dependency>
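With this pom it may be worth checking which TensorFlow JNI jars actually end up on the classpath. To my understanding, org.tensorflow:tensorflow already depends on libtensorflow plus the CPU-only libtensorflow_jni, and the TensorFlow 1.x Java install guide suggests libtensorflow together with libtensorflow_jni_gpu instead of it for GPU use. The helper below is a hypothetical diagnostic (JniClasspathCheck is not a TensorFlow class): it only inspects jar file names in a classpath string and assumes Maven's default artifactId-version.jar naming.

```java
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

/** Hypothetical helper: reports TensorFlow JNI jars found in a classpath string. */
public class JniClasspathCheck {

    /** Returns the file names of CPU or GPU TensorFlow JNI jars on the classpath. */
    static List<String> findJniJars(String classpath) {
        List<String> found = new ArrayList<>();
        for (String entry : classpath.split(Pattern.quote(File.pathSeparator))) {
            String name = new File(entry).getName();
            // Assumes Maven's default naming, e.g. libtensorflow_jni-1.15.0.jar
            if (name.startsWith("libtensorflow_jni-") || name.startsWith("libtensorflow_jni_gpu-")) {
                found.add(name);
            }
        }
        return found;
    }

    /** True if both the CPU-only and the GPU JNI jars are present. */
    static boolean hasBothCpuAndGpuJni(String classpath) {
        boolean cpu = false;
        boolean gpu = false;
        for (String name : findJniJars(classpath)) {
            if (name.startsWith("libtensorflow_jni_gpu-")) {
                gpu = true;
            } else {
                cpu = true;
            }
        }
        return cpu && gpu;
    }

    public static void main(String[] args) {
        String cp = System.getProperty("java.class.path");
        System.out.println("TensorFlow JNI jars on classpath: " + findJniJars(cp));
        if (hasBothCpuAndGpuJni(cp)) {
            // Which native library is extracted and loaded may then depend on classpath order.
            System.out.println("Both CPU and GPU JNI jars are present.");
        }
    }
}
```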

I don't understand why this error occurs, and I haven't been able to fix it.

0 answers

No answers yet