我想在本地运行模型。我正在尝试从网络课程中训练和预测模型:
已使用上述代码训练了模型。这是YOLO对象检测模型,用于检测使用tf.estimator
建造的飞机。使用提供的代码成功完成了培训,但是我不知道如何推断模型。
import tensorflow as tf
# DATA
DATA = './samples/airplane_sample.png'
# Model: This directory contains saved_model.pb and variables
SAVED_MODEL_DIR = './1559196417/'
def decode_image():
img_bytes = tf.read_file(DATA)
decoded = tf.image.decode_image(img_bytes, channels=3)
return tf.cast(decoded, dtype=tf.uint8)
def main1():
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], SAVED_MODEL_DIR)
img = decode_image()
result = sess.run(['classes'], feed_dict={'input': img})
print(result)
def main2():
model = tf.contrib.predictor.from_saved_model(SAVED_MODEL_DIR)
pred = model({'image_bytes': [decode_image()], 'square_size': [tf.placeholder(tf.int32)]})
print(pred)
if __name__ == "__main__":
main2()
以上是我编写的代码,但无效。甚至我都不知道有什么问题。输入类型不正确? API不当?你能给我一些建议吗?
答案 0 :(得分:0)
也许这会起作用:
import tensorflow as tf
import cv2
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile('./1559196417/saved_model.pb', 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
image = cv2.imread('./samples/airplane_sample.png')
rgb_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
rgb_img_expanded = np.expand_dims(rgb_img, axis=0)
image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
classes = detection_graph.get_tensor_by_name('classes:0')
result = sess.run([classes],feed_dict={image_tensor: rgb_img_expanded})
答案 1 :(得分:0)
首先在python之外的终端中运行saved_model_cli show --all --dir SAVED_MODEL_DIR
,以检查保存的模型并检查其是否具有正确的标签,输入和输出。从那里开始,需要花费一些精力才能从API中获取必要的信息。
def extract_tensors(signature_def, graph):
output = dict()
for key in signature_def:
value = signature_def[key]
if isinstance(value, tf.TensorInfo):
output[key] = graph.get_tensor_by_name(value.name)
return output
def extract_tags(signature_def, graph):
output = dict()
for key in signature_def:
output[key] = dict()
output[key]['inputs'] = extract_tensors(
signature_def[key].inputs, graph)
output[key]['outputs'] = extract_tensors(
signature_def[key].outputs, graph)
return output
with tf.Session(graph=tf.Graph()) as session:
serve = tf.saved_model.load(
session, tags=['serve'], export_dir=SAVED_MODEL_DIR)
tags = extract_tags(serve.signature_def, session.graph)
model = tags['serving_default']
您可以从那里尝试print(model['inputs'], model['outputs'])
来查看导出了哪些输入和输出,如果它们与saved_model_cli
一致,如果您需要另一个标签,只需将其替换为serving_default
。