如何使用c ++ api加载由SavedModel生成的Tensorflow模型?

时间:2019-03-06 02:18:40

标签: c++ tensorflow

我已将模型保存在python中,下面是我的代码:

B_cur = tf.placeholder(tf.float32, [1, img_height, img_width, 
        3*num_input_imgs])
B_mid=tf.placeholder(tf.float32,[batch_size,img_height,img_width,3])
F = tf.placeholder(tf.float32, [batch_size, img_height//2, img_width//2, 
net_channel//2])
H = tf.placeholder(tf.float32, [1, batch_size, img_height//2, img_width//2, 
net_channel])
with tf.variable_scope("model") as scope:
    (L, F_o, H_o, A) = model.ovd(B_cur,B_mid, F, H, net_channel)
model_path = './model-613104'
builder=tf.saved_model.builder.SavedModelBuilder('./SavedModel/')
inputs={'B_curr':tf.saved_model.utils.build_tensor_info(B_cur),'B_mid':tf.saved_model.utils.build_tensor_info(B_mid),'F':tf.saved_model.utils.build_tensor_info(F),'H':tf.saved_model.utils.build_tensor_info(H)}
outputs={'L':tf.saved_model.utils.build_tensor_info(L),'F_o':tf.saved_model.utils.build_tensor_info(F_o),'H_o':tf.saved_model.utils.build_tensor_info(H_o)}
signature=tf.saved_model.signature_def_utils.build_signature_def(inputs,outputs,'ovd_signature')

我知道该方法在python中加载它,但是如何在c ++中加载它呢? 或者在哪里可以获得有关如何在c ++中加载保存的模型的详细说明? 下面是在Python中加载模型的代码:

graph=tf.saved_model.loader.load(sess,['ovd_graph'],'./SavedModel/')
signature_key='ovd_signature'
#extract signature from graph
signature=graph.signature_def
#extract tensor name from signature   
B_curr_name=signature[signature_key].inputs['B_curr'].name
B_mid_name = signature[signature_key].inputs['B_mid'].name
F_name=signature[signature_key].inputs['F'].name
H_name=signature[signature_key].inputs['H'].name
L_name=signature[signature_key].outputs['L'].name
F_o_name=signature[signature_key].outputs['F_o'].name
H_o_name=signature[signature_key].outputs['H_o'].name
#find tensor according to names
B_cur=sess.graph.get_tensor_by_name(B_curr_name)
B_mid=sess.graph.get_tensor_by_name(B_mid_name)
F=sess.graph.get_tensor_by_name(F_name)
H=sess.graph.get_tensor_by_name(H_name)
L=sess.graph.get_tensor_by_name(L_name)
F_o=sess.graph.get_tensor_by_name(F_o_name)
H_o=sess.graph.get_tensor_by_name((H_o_name))

以上是我使用python加载模型的代码,我需要所有这些特征图(B_rur,B_mid,F,H)作为输入,因此在保存模型时保存它们,在加载模型时加载它们,我不知道有没有更简单的方法?如何在c ++中翻译它们?

0 个答案:

没有答案