我使用以下代码将经过预训练的TensorFlow模型转换为.pb
文件
import tensorflow as tf
from argparse import ArgumentParser
def main():
parser = ArgumentParser()
parser.add_argument('--checkpoint', type=str,
dest='checkpoint',
help='dir or .ckpt file to load checkpoint from',
metavar='CHECKPOINT', required=True)
parser.add_argument('--model', type=str,
dest='model',
help='.meta for your model',
metavar='MODEL', required=True)
parser.add_argument('--out-path', type=str,
dest='out_path',
help='model output directory',
metavar='MODEL_OUT', required=True)
opts = parser.parse_args()
tf.reset_default_graph()
saver = tf.train.import_meta_graph(opts.model)
#builder = tf.saved_model.builder.SavedModelBuilder(opts.out_path)
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
# Restore variables from disk.
saver.restore(sess, opts.checkpoint)
print("Model restored.")
#builder.add_meta_graph_and_variables(sess,['tfckpt2pb'],strip_default_attrs=False)
#builder.save()
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['Tower_0/parsing_fc/BiasAdd','Tower_0/parsing_rf_fc/BiasAdd','Tower_0/edge_rf_fc/BiasAdd'])
with tf.gfile.FastGFile(opts.out_path, mode='wb') as f:
f.write(constant_graph.SerializeToString())
print("pb Model saved.")
if __name__ == '__main__':
main()
然后导入带有Java代码的.pb
模型,如下所示
//import model
byte[] graphBytes = IOUtils.toByteArray(new FileInputStream(MODEL_PATH));
graph.importGraphDef(graphBytes);
//create session
try(Session session = new Session(graph)){
ConfigProto config = ConfigProto.newBuilder()
.setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true))
.build();
//get the output
Tensor<?> output = session.runner()
.setOptions(config.toByteArray())
.feed("Tower_0/strided_slice", imageTensor)
.fetch("Tower_0/parsing_fc/BiasAdd").run().get(0);
System.out.println(output);
}
然后错误为波纹管
//这是错误消息2020-04-06 12:13:50.269556:I tensorflow / core / platform / cpu_feature_guard.cc:142]您的CPU支持 TensorFlow二进制文件未编译使用的指令:AVX2 AVX512F FMA 2020-04-06 12:13:50.281419:I tensorflow / core / platform / profile_utils / cpu_utils.cc:94] CPU频率: 2600000000赫兹 2020-04-06 12:13:50.288167:I tensorflow / compiler / xla / service / service.cc:168] XLA服务 为平台主机初始化了0x7f968ded1db0(这不保证 将会使用XLA)。设备: 2020-04-06 12:13:50.288206:I tensorflow / compiler / xla / service / service.cc:176] StreamExecutor 设备(0):主机,默认版本 java.lang.IllegalArgumentException:无法为操作Tower_0 / strided_slice分配设备:{{node Tower_0 / strided_slice}}原为 明确分配给/ device:GPU:0,但可用设备为[ / job:本地主机/副本:0 /任务:0 /设备:CPU:0, / job:localhost /副本:0 /任务:0 /设备:XLA_CPU:0]。确保 设备规范是指有效的设备。请求的设备 似乎是GPU,但未启用CUDA。 [[Tower_0 / strided_slice]] 在org.tensorflow.Session.run(本机方法) 在org.tensorflow.Session.access $ 100(Session.java:48) 在org.tensorflow.Session $ Runner.runHelper(Session.java:326) 在org.tensorflow.Session $ Runner.run(Session.java:276)
我的pom配置
//here are maven configuration
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.15.0</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
<version>1.15.0</version>
</dependency>
我不知道为什么会发生此错误,并且我无法解决问题。