Tensorflow:我的.pb文件有什么问题?我如何使其工作?

时间:2017-06-21 06:44:31

标签: python tensorflow

我想用我的模型对图像进行分类,模型是.pb文件。现在这里有一些问题。分类代码在这里:

# -*- coding: utf-8 -*-
from PIL import Image
import tensorflow as tf
import numpy as np
import re
import os

model_dir = '/home/vrview/tensorflow/example/char/tfrecords/try1/'
image_dir = '/home/vrview/tensorflow/example/char/test_abc/19.jpg'
class NodeLookup(object):
    def __init__(self,
                 label_lookup_path=None,
                 uid_lookup_path=None):
        if not label_lookup_path:
            label_lookup_path = os.path.join(
                model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
        if not uid_lookup_path:
            uid_lookup_path = os.path.join(
                model_dir, 'alex_number.txt')
        self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

    def load(self, label_lookup_path, uid_lookup_path):
        if not tf.gfile.Exists(uid_lookup_path):
            tf.logging.fatal('File does not exist %s', uid_lookup_path)
        if not tf.gfile.Exists(label_lookup_path):
            tf.logging.fatal('File does not exist %s', label_lookup_path)
        proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
        uid_to_human = {}
        p = re.compile(r'[n\d]*[ \S,]*')
        for line in proto_as_ascii_lines:
            parsed_items = p.findall(line)
            uid = parsed_items[0]
            human_string = parsed_items[2]
            uid_to_human[uid] = human_string
        node_id_to_uid = {}
        proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
        for line in proto_as_ascii:
            if line.startswith('  target_class:'):
                target_class = int(line.split(': ')[1])
            if line.startswith('  target_class_string:'):
                target_class_string = line.split(': ')[1]
                node_id_to_uid[target_class] = target_class_string[1:-2]
        node_id_to_name = {}
        for key, val in node_id_to_uid.items():
            if val not in uid_to_human:
                tf.logging.fatal('Failed to locate: %s', val)
            name = uid_to_human[val]
            node_id_to_name[key] = name    
        return node_id_to_name    
    def id_to_string(self, node_id):
        if node_id not in self.node_lookup:
            return ''
        return self.node_lookup[node_id]
def create_graph():
    with tf.gfile.FastGFile(os.path.join(
            model_dir, 'alex_model22.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
def get_one_image():
    image = Image.open(image_dir)
    #plt.imshow(image)
    image = image.resize([56, 56])    
    image = np.array(image)
    return image    
image_array=get_one_image()  
image = tf.cast(image_array, tf.float32)
image_1 = tf.image.per_image_standardization(image) 
image_2 = tf.reshape(image_1, [1, 56, 56, 3])

create_graph()    
sess = tf.Session()  
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')   
predictions = sess.run(softmax_tensor, {'x-input': image_array})   
predictions = np.squeeze(predictions)

node_lookup = NodeLookup()
top_5 = predictions.argsort()[-5:][::-1]
for node_id in top_5:
    human_string = node_lookup.id_to_string(node_id)
    score = predictions[node_id]
    print('%s (score = %.5f)' % (human_string, score))   
sess.close()

现在我得到这样的错误:

Traceback (most recent call last):
  File "/home/vrview/tensorflow/example/char/tfrecords/try1/Alex_Read_pb.py", line 107, in <module>
    predictions = sess.run(softmax_tensor, {'x-input': image_array})
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/home/vrview/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 922, in _run
    + e.args[0])
TypeError: Cannot interpret feed_dict key as Tensor: The name 'x-input' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".

即使我知道错误,但我也不知道如何修复它。这里我将提供我的代码来创建.pb文件。我认为有些事情应该改变。

# coding=utf-8
from  color_1 import read_and_decode, get_batch, get_test_batch
from tensorflow.python.framework import graph_util
import tensorflow as tf

batch_size = 128
TRAIN_STEPS = 10000
crop_size = 56
REGULARAZTION_RATE=0.0001
MODEL_SAVE_PATH = "/home/vrview/tensorflow/example/char/tfrecords/try1/"
MODEL_NAME = "model.ckpt"   
def inference(input_tensor,train,regularizer):
    with tf.name_scope('conv1') as scope:
        kernel=tf.Variable(tf.truncated_normal([11,11,3,96],dtype=tf.float32,stddev=1e-1),name='weights1')
        conv=tf.nn.conv2d(input_tensor,kernel,[1,4,4,1],padding='SAME',name='con1')
        biases=tf.Variable(tf.constant(0.0,shape=[96],dtype=tf.float32),trainable=True,name='biases1')
        bias=tf.nn.bias_add(conv,biases)
        conv1=tf.nn.relu(bias,name=scope)
        lrn1=tf.nn.lrn(conv1,4,bias=1.0,alpha=0.001/9,beta=0.75,name='lrn1')
        pool1=tf.nn.max_pool(lrn1,ksize=[1,3,3,1],strides=[1,2,2,1],padding='VALID',name='pool1')

    with tf.name_scope('conv2') as scope:
        ......
        ......
    with tf.name_scope('conv3') as scope:
        ......
        ......
    with tf.name_scope('conv4') as scope:
        ......
        ......
    with tf.name_scope('conv5') as scope:
        .......

    pool_shape=pool5.get_shape().as_list()
    nodes=pool_shape[1]*pool_shape[2]*pool_shape[3]
    reshaped=tf.reshape(pool5,[pool_shape[0],nodes])

    with tf.variable_scope('layer6-fc1'):
        fc1_weights=tf.get_variable("weight6",[nodes,4096],initializer=tf.truncated_normal_initializer(stddev=0.1))
        if regularizer !=None:
            tf.add_to_collection('losses1',regularizer(fc1_weights))
        fc1_biases=tf.get_variable("biases6",[4096],initializer=tf.truncated_normal_initializer(0.1))

        fc1=tf.nn.relu(tf.matmul(reshaped,fc1_weights)+fc1_biases)
        if train:fc1=tf.nn.dropout(fc1,0.5)

    with tf.variable_scope('layer7-fc2'):
       ......
       ......
    with tf.variable_scope('layer8-fc3'):
        ......
        ......
        logit=tf.matmul(fc2,fc3_weights)+fc3_biases
    return logit
def train(batch_x, batch_y):
    image_holder = tf.placeholder(tf.float32, [batch_size, 56, 56, 3], name='x-input')
    label_holder = tf.placeholder(tf.float32, [batch_size], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
    y = inference(image_holder, train,regularizer)
    global_step = tf.Variable(0, trainable=False)    
    def loss(logits, labels):
        labels = tf.cast(labels, tf.int64)
        y_conv = tf.nn.softmax(y, name='softmax')
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='cross_entropy_per_example')

        cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
        tf.add_to_collection('losses', cross_entropy_mean)
        return tf.add_n(tf.get_collection('losses'), name='total_loss')
    loss = loss(y, label_holder)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    saver = tf.train.Saver(max_to_keep=30)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)   
        f=open('/home/vrview/tensorflow/example/char/tfrecords/try1/loss.txt','w')
        graph_def=tf.get_default_graph().as_graph_def()
        for i in range(TRAIN_STEPS):
            image_batch, label_batch = sess.run([batch_x, batch_y])
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={image_holder: image_batch,                                                                                      label_holder: label_batch})
            if i % 100 == 0:
                format_str = ('After %d step,loss on training batch is: %g')
                print (format_str % (i, loss_value))
            f.write(str(i+1) +', loss_value:'+ str(loss_value) + '\n')
        output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def,['softmax','x-input'])
        with tf.gfile.GFile("/home/vrview/tensorflow/example/char/tfrecords/try1/alex_model22.pb",
                            "wb") as f:
            f.write(output_graph_def.SerializeToString())
            f.close()    
        coord.request_stop()
        coord.join(threads) 
def main(argv=None):
    image, label = read_and_decode('train.tfrecords')
    batch_image, batch_label = get_batch(image, label, batch_size, crop_size)   
    train(batch_image, batch_label)    
if __name__ == '__main__':
    tf.app.run()

也许我不应该保留这一点&#39; x-input&#39; 。但是我不知道应该保存什么。有人知道如何解决这个问题吗?非常感谢你。 enter image description here 图片是我的代码图表

0 个答案:

没有答案