我已经训练了一个带有tensorflow的GAN,现在我想在我的c ++项目中使用它。 我的GAN是这样的(输入和输出都是图像):
image = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 3, SIZE, SIZE])
input = 2*((tf.cast(image, tf.float32)/255.)-.5) #0~255 to -1~1
output = GAN(input) #GAN is my network including many modules
我注意到有一个saved_model
工具可以将我的模型保存到saved_model.pb
中,我可以直接在C ++中使用它。
我这样做的代码是这样的:
tensor_input_info = tf.saved_model.utils.build_tensor_info(input)
tensor_output_info = tf.saved_model.utils.build_tensor_info(output)
gan_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={'image': tensor_input_info},
outputs={'result': tensor_output_info},
method_name='gan'
)
)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
session, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'my_gan':gan_signature
},
legacy_init_op=legacy_init_op)
builder.save()
这里我不确定dict中的键。在这段代码中,我使用“image”作为我输入的关键,但我不知道它是否正确。即使我成功了saved_model.pb
。
现在我不知道该怎么做,我怎样才能在我的C ++项目中使用它?