How to import a model from a .pb file in TensorFlow

Time: 2018-07-25 04:48:27

Tags: tensorflow image-segmentation tensorboard semantic-segmentation

I have been assigned the task of fine-tuning DeepLab V3+ using TensorFlow and Python. For that purpose I downloaded the frozen model from the DeepLab GitHub page: [Sample Image]

I downloaded this file. Then I searched the web for how to create a model from these files: [Sample Image]

The only methods I found create a model from .ckpt and .meta files, but I don't have either of those files.

The other methods I found only create a graph from the .pb file, and I don't know what to do after creating the graph. I want to import the frozen model using these files. Thank you in advance.

1 answer:

Answer 0 (score: 1)

This should work for loading the frozen graph and running inference:

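```python
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf

# Tensor names of the DeepLab frozen inference graph
INPUT_TENSOR_NAME = 'ImageTensor:0'
OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
INPUT_SIZE = 513  # the longer image side is resized to this many pixels

# Parse the serialized GraphDef stored in the frozen .pb file.
# tf.compat.v1 keeps this working on recent TF 1.x as well as TF 2.x.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile('model/frozen_inference_graph.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# import_graph_def returns None (not a graph) unless return_elements is
# given, so import into an explicit Graph and pass that to the Session.
graph = tf.Graph()
with graph.as_default():
    tf.compat.v1.import_graph_def(graph_def, name='')
sess = tf.compat.v1.Session(graph=graph)
```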


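```python
def run(image):
    """Resize a PIL image and return (resized_image, segmentation_map)."""
    width, height = image.size
    # Scale so the longer side equals INPUT_SIZE, keeping the aspect ratio
    resize_ratio = 1.0 * INPUT_SIZE / max(width, height)
    target_size = (int(resize_ratio * width), int(resize_ratio * height))
    # Image.LANCZOS is the filter the old Image.ANTIALIAS alias pointed to;
    # ANTIALIAS was removed in Pillow 10
    resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
    # The graph takes a batch of images, so wrap the single array in a list
    batch_seg_map = sess.run(
        OUTPUT_TENSOR_NAME,
        feed_dict={INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
    seg_map = batch_seg_map[0]  # drop the batch dimension
    return resized_image, seg_map
```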

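The resizing step in `run` scales the image so that its longer side equals `INPUT_SIZE` while keeping the aspect ratio. Pulling that arithmetic into a standalone function makes it easy to check which size the network will actually see. The name `compute_target_size` is a hypothetical helper for illustration, not part of DeepLab; it mirrors the two lines in `run` exactly.

```python
def compute_target_size(width, height, input_size=513):
    """Return (new_width, new_height) with the longer side scaled to input_size.

    Same arithmetic as run(): int() truncates, so a fractional
    shorter side is rounded down.
    """
    resize_ratio = 1.0 * input_size / max(width, height)
    return int(resize_ratio * width), int(resize_ratio * height)


if __name__ == '__main__':
    print(compute_target_size(1026, 513))  # (513, 256): 256.5 truncates to 256
```

For example, a 1026×513 image is fed to the network as 513×256, because the ratio is exactly 0.5 and the shorter side 256.5 is truncated.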
```python
input_image = Image.open('test.jpg')
resized_im, seg_map = run(input_image)

fig = plt.figure()
fig.add_subplot(1, 2, 1)
plt.imshow(resized_im)
fig.add_subplot(1, 2, 2)
# Mask label 0 (background) so only the segmented objects are coloured
plt.imshow(np.ma.masked_equal(seg_map, 0))
plt.show()
```