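I'm a beginner in deep learning. I'm trying to measure the accuracy of a saved TensorFlow model [ResNet on CIFAR-10] (a .pb file built with the low-level TensorFlow API). Here is what I'm trying to do:
The problem is that the accuracy doesn't look plausible, so I'm sure I'm doing something wrong in the session run.
The code is as follows: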
import tensorflow as tf  # TensorFlow 1.x API (tf.gfile, tf.GraphDef, tf.Session)
import numpy as np
from keras.datasets import cifar10

def load_graph(frozen_graph_filename):
    # Load the protobuf file from disk and parse it to retrieve the
    # deserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Then import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        # name="prefix" is prepended to every op/node in the imported graph,
        # which is why the tensors below are looked up as "prefix/..."
        tf.import_graph_def(graph_def, name="prefix")
    return graph

graph = load_graph("./model_original.pb")

# We can verify that we can access the list of operations in the graph
#for op in graph.get_operations():
#    print(op.name)

x = graph.get_tensor_by_name('prefix/net_input:0')
y = graph.get_tensor_by_name('prefix/net_output:0')
print(x.shape)
print(y.shape)

# The data, split between train and test sets.
# Note: y_train / y_test are integer class labels of shape (N, 1), not one-hot vectors.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Launch a session. x_test should get the same preprocessing that was applied to the
# training images (e.g. scaling to [0, 1] or mean subtraction); feeding raw uint8
# pixels can produce wrong predictions if the model was trained on normalized input.
with tf.Session(graph=graph) as sess:
    y_out = sess.run(y, feed_dict={x: x_test})
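If feeding all 10,000 test images in a single `sess.run` call is too heavy for memory, the prediction step can be split into mini-batches. The following is a minimal sketch, not part of the original code: `predict_in_batches` and `fake_predict` are hypothetical names, and `fake_predict` only stands in for the network so the snippet runs without the .pb file. In the script above, the callable would be `lambda b: sess.run(y, feed_dict={x: b})`, called inside the `with tf.Session(...)` block. This assumes `net_input` accepts a variable batch size (check the printed `x.shape`).

```python
import numpy as np

def predict_in_batches(predict_fn, inputs, batch_size=500):
    """Run predict_fn on consecutive slices of inputs and stack the outputs."""
    outputs = []
    for start in range(0, len(inputs), batch_size):
        batch = inputs[start:start + batch_size]  # last slice may be shorter
        outputs.append(predict_fn(batch))
    return np.concatenate(outputs, axis=0)

# Hypothetical stand-in for the network: 10 scores per image.
def fake_predict(batch):
    return np.tile(np.arange(10, dtype=np.float32), (len(batch), 1))

# Dummy data with the CIFAR-10 image shape (32x32 RGB).
images = np.zeros((1234, 32, 32, 3), dtype=np.uint8)
scores = predict_in_batches(fake_predict, images, batch_size=500)
print(scores.shape)  # (1234, 10)
```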
Then I try to check the accuracy as follows:
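from mlxtend.evaluate import confusion_matrix

# y_test has shape (N, 1), so np.argmax(y_test, axis=1) would return all zeros;
# flatten it instead to get the true class index of each image.
cm = confusion_matrix(y_test.ravel(), np.argmax(y_out, axis=1))

def accuracy(cm):
    # Correct predictions lie on the diagonal of the confusion matrix
    diagonal_sum = cm.trace()
    sum_of_all_elements = cm.sum()
    return diagonal_sum / sum_of_all_elements

print(accuracy(cm) * 100)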
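The label shape is easy to check in isolation. The sketch below uses only NumPy and made-up toy arrays (`y_test_demo`, `y_out_demo`, and the helpers `accuracy_from_labels` and `confusion` are hypothetical, not from the original code). It shows that `np.argmax` over the length-1 label axis always yields 0, and computes the accuracy both directly and from a hand-built confusion matrix, without mlxtend.

```python
import numpy as np

def accuracy_from_labels(y_true, logits):
    """Fraction of samples whose arg-max prediction equals the integer label."""
    y_true = np.asarray(y_true).ravel()      # (N, 1) -> (N,)
    y_pred = np.argmax(logits, axis=1)       # (N, 10) -> (N,)
    return float(np.mean(y_pred == y_true))

def confusion(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true).ravel(), np.asarray(y_pred).ravel()), 1)
    return cm

# Toy labels shaped the way cifar10.load_data() returns them: (N, 1) integers.
y_test_demo = np.array([[3], [8], [0], [6]])
# Toy network outputs: one row of 10 scores per image, predicting 3, 8, 1, 6.
y_out_demo = np.zeros((4, 10))
y_out_demo[[0, 1, 2, 3], [3, 8, 1, 6]] = 1.0

print(np.argmax(y_test_demo, axis=1))                  # [0 0 0 0]
print(accuracy_from_labels(y_test_demo, y_out_demo))   # 0.75
cm = confusion(y_test_demo, np.argmax(y_out_demo, axis=1), num_classes=10)
print(cm.trace() / cm.sum())                           # 0.75
```

With the all-zero "labels", only predictions of class 0 count as correct, which on the roughly balanced CIFAR-10 test set would push the measured accuracy toward about 10% no matter how good the model is.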