How to display TensorFlow class names after prediction

Date: 2017-11-23 01:49:42

Tags: python machine-learning tensorflow

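When I run a prediction with TensorFlow, how do I get the name of the class associated with the prediction? Right now it only returns an array of probabilities. This is the code I use to predict an image: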

import cv2
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.import_meta_graph)


class Prediction:

    def __init__(self, filename, filepath, image_size=128, number_channels=3):

        self.x_batch = []
        self.images = []
        self.image_size = image_size
        self.number_channels = number_channels

        self.image = cv2.imread(filename)

        self.modelpath = filepath
        self.modelfilepath = filepath + '/train-model.meta'

        self.sess = tf.Session()
        self.graph = None
        self.y_pred = None

    def resize_image(self):
        # Pass interpolation by keyword: the third positional argument of cv2.resize is dst
        self.image = cv2.resize(self.image, (self.image_size, self.image_size),
                                interpolation=cv2.INTER_LINEAR)
        self.images.append(self.image)
        self.images = np.array(self.images, dtype=np.uint8)
        self.images = self.images.astype('float32')
        # Scale pixel values to [0, 1]
        self.images = np.multiply(self.images, 1.0 / 255.0)
        self.x_batch = self.images.reshape(1, self.image_size, self.image_size, self.number_channels)

    def restore_model(self):

        # Rebuild the graph from the .meta file, then load the latest checkpoint weights
        saver = tf.train.import_meta_graph(self.modelfilepath)
        saver.restore(self.sess, tf.train.latest_checkpoint(self.modelpath))

        self.graph = tf.get_default_graph()

        self.y_pred = self.graph.get_tensor_by_name("y_pred:0")

    def predict_image(self):
        x = self.graph.get_tensor_by_name("x:0")
        y_true = self.graph.get_tensor_by_name("y_true:0")
        # Dummy values for the y_true placeholder: shape is (batch size, number of classes), 2 here
        y_test_images = np.zeros((1, 2))

        feed_dict_testing = {x: self.x_batch, y_true: y_test_images}
        # One row of class probabilities per image in x_batch
        result = self.sess.run(self.y_pred, feed_dict=feed_dict_testing)
        return result

Thanks for your help.

1 Answer

Answer 0 (score: 0)

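It would help to see your training code, to know how you measure accuracy against the ground truth. That said, the network only returns one probability per class index; the class names are not part of that output. You need a label file (one class name per line, in the same order as the class indices used during training) that you can use like this: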

        predictions = self.sess.run(self.y_pred, feed_dict=feed_dict_testing)

        # Format predicted classes for display
        #   use np.squeeze to convert the (1, num_classes) output to a 1-D vector of probability values
        predictions = np.squeeze(predictions)

        top_k = predictions.argsort()[-5:][::-1]  # Indices of the top 5 predictions, highest first

        #   read the class labels from the label file (labelPath: path to your label file);
        #   text mode returns str rather than bytes, so no b'...' prefixes need filtering
        with open(labelPath, 'r') as f:
            labels = [line.strip() for line in f]
        print("")
        print("Image Classification Probabilities")
        #   Output the class probabilities in descending order
        for node_id in top_k:
            human_string = labels[node_id]
            score = predictions[node_id]
            print('{0:s} (score = {1:.5f})'.format(human_string, score))
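The same top-k lookup can be written as a small function that does not depend on TensorFlow. This is a minimal sketch, not part of the TensorFlow example: `top_k_labels` and the sample class names are hypothetical, and it assumes `labels[i]` is the name of the class whose one-hot index was `i` during training. It accepts a plain list or the 1-D array returned by `np.squeeze`.

```python
def top_k_labels(probabilities, labels, k=5):
    """Return up to k (label, score) pairs, highest score first.

    probabilities: 1-D sequence of class probabilities (e.g. np.squeeze(result)).
    labels: class names in the same index order used to build y_true in training
            (an assumption; the training code is not shown).
    """
    if len(probabilities) != len(labels):
        raise ValueError("expected exactly one label per output probability")
    # Sort class indices by probability, descending, and keep the first k.
    order = sorted(range(len(probabilities)),
                   key=lambda i: probabilities[i], reverse=True)
    return [(labels[i], float(probabilities[i])) for i in order[:k]]


if __name__ == "__main__":
    # Hypothetical two-class output, shaped like np.squeeze(result) from predict_image().
    probs = [0.2, 0.8]
    names = ["cat", "dog"]  # hypothetical class names
    for name, score in top_k_labels(probs, names, k=5):
        print("{0:s} (score = {1:.5f})".format(name, score))
    # prints:
    # dog (score = 0.80000)
    # cat (score = 0.20000)
```

Sorting indices in plain Python keeps the sketch dependency-free; with NumPy, `predictions.argsort()[::-1][:k]` gives the same order, ties aside.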

Adapted from the TensorFlow example for retraining Inception. Hope this helps.
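If no label file exists yet, one way to create it is to write it from the same class list the training code uses to build the one-hot `y_true` vectors, so that line `i` names output index `i`. This is a minimal sketch under that assumption: `write_label_file`, `read_label_file`, `one_hot`, the class list and the file location are all hypothetical, since the training code is not shown.

```python
import os
import tempfile


def write_label_file(path, class_names):
    """Write one class name per line; line i names the class with one-hot index i."""
    with open(path, "w", encoding="utf-8") as f:
        for name in class_names:
            f.write(name + "\n")


def read_label_file(path):
    """Read class names back in file order, dropping newlines and blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def one_hot(label, class_names):
    """Build the y_true row for one training example from the same class list."""
    row = [0.0] * len(class_names)
    row[class_names.index(label)] = 1.0
    return row


if __name__ == "__main__":
    classes = ["cat", "dog"]  # hypothetical; use the class order from your training code
    path = os.path.join(tempfile.mkdtemp(), "labels.txt")  # hypothetical location
    write_label_file(path, classes)
    print(read_label_file(path))    # ['cat', 'dog']
    print(one_hot("dog", classes))  # [0.0, 1.0]
```

The key design point is that training and prediction share a single ordered class list: as long as the file is written from that list, `labels[node_id]` in the answer's loop names the right class.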