我用 TensorFlow 编写了一个 CNN 网络，它工作正常。现在我想在测试阶段查看每张图像被分到了哪一类。
在我的数据集中有 5 个不同的类别。我在寻找一种方法，在测试阶段把图像按预测类别分别保存到各自的新文件夹中，以便检查网络哪些分类正确、哪些分类错误。
这是我在测试阶段的全部代码:
# Command-line flags (tf.app.flags) for the test run; values are read via FLAGS.

# NOTE(review): main() never reads learning_rate — presumably kept so the flag
# set mirrors the training script.
tf.app.flags.DEFINE_float(
    'learning_rate', 1e-4,
    'Learning rate for adam optimizer')

# num_classes, num_channel and img_size determine variable shapes in main(),
# so they must agree with the checkpoint that is restored.
# NOTE(review): the question mentions 5 classes while this default is 3 —
# confirm it matches the trained model.
tf.app.flags.DEFINE_integer(
    'num_classes', 3,
    'Number of classes')
# The input placeholder in main() has this fixed batch dimension.
tf.app.flags.DEFINE_integer(
    'batch_size', 128,
    'Batch size')
tf.app.flags.DEFINE_float(
    'keep_prob', 0.8,
    'Dropout keep probability')
tf.app.flags.DEFINE_integer(
    'num_channel', 3,
    'Image channel, RGB=3, Grayscale=1')
tf.app.flags.DEFINE_integer(
    'img_size', 80,
    'Size of images')
tf.app.flags.DEFINE_string(
    'test_file', 'data/test.txt',
    'Test dataset file')

FLAGS = tf.app.flags.FLAGS

# Despite its name this is a checkpoint path prefix handed to Saver.restore,
# not a directory.
checkpoint_dir = '/home/xyrio/Desktop/classier/training/checkpoints/model_epoch.ckpt89'
def main(_):
    """Restore the trained CNN, evaluate it on the test set and record predictions.

    Rebuilds the network graph, restores the weights saved at ``checkpoint_dir``
    and runs every *full* test batch. It prints the mean accuracy and the
    confusion matrix summed over all batches. The predicted class index of every
    evaluated image is written, one per line, to ``test_predictions.txt``. The
    images can then be sorted into per-class folders.

    Args:
        _: Unused positional argument (presumably argv from ``tf.app.run``).
    """
    x = tf.placeholder(
        tf.float32,
        shape=[FLAGS.batch_size, FLAGS.img_size, FLAGS.img_size, FLAGS.num_channel],
        name='x')
    y_true = tf.placeholder(tf.float32, shape=[None, FLAGS.num_classes], name='y_true')
    y_true_cls = tf.argmax(y_true, axis=1)

    # Dropout keep probability. It defaults to 1.0 (dropout disabled) because this
    # script only runs inference. Applying FLAGS.keep_prob here would make every
    # test prediction random. The placeholder has no variables, so it does not
    # affect restoring the checkpoint.
    keep_prob = tf.placeholder_with_default(1.0, shape=(), name='keep_prob')

    filter_size_conv1 = 3
    num_filters_conv1 = 32
    filter_size_conv2 = 3
    num_filters_conv2 = 32
    filter_size_conv3 = 3
    num_filters_conv3 = 64
    filter_size_conv4 = 3
    num_filters_conv4 = 128
    filter_size_conv5 = 3
    num_filters_conv5 = 256
    fc_layer_size = 512
    fc_layer_size2 = 128

    def create_weights(shape):
        """Return a weight variable drawn from a truncated normal (mean 0, stddev 0.01)."""
        return tf.Variable(tf.truncated_normal(shape, mean=0, stddev=0.01))

    def create_biases(size):
        """Return a bias variable of length ``size`` filled with 0.01."""
        return tf.Variable(tf.constant(0.01, shape=[size]))

    def create_convolutional_layer(input, num_input_channels, conv_filter_size, num_filters,
                                   useBatchNorm=False, usePooling=True):
        """Conv (stride 1, SAME) + bias + ReLU, then optional batch norm and 2x2 max-pool."""
        weights = create_weights(shape=[conv_filter_size, conv_filter_size,
                                        num_input_channels, num_filters])
        biases = create_biases(num_filters)
        layer = tf.nn.conv2d(input=input, filter=weights, strides=[1, 1, 1, 1], padding='SAME')
        layer += biases
        layer = tf.nn.relu(layer)
        if useBatchNorm:
            layer = tf.layers.batch_normalization(layer)
        if usePooling:
            layer = tf.nn.max_pool(value=layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                                   padding='SAME')
        return layer

    def create_flatten_layer(layer):
        """Flatten a (batch, h, w, c) tensor to (batch, h * w * c)."""
        num_features = layer.get_shape()[1:4].num_elements()
        return tf.reshape(layer, [-1, num_features])

    def create_fc_layer(input, num_inputs, num_outputs, useRelu=True, useDropout=False):
        """Fully connected layer ``input @ W + b`` with optional ReLU and dropout."""
        weights = create_weights(shape=[num_inputs, num_outputs])
        biases = create_biases(num_outputs)
        layer = tf.matmul(input, weights) + biases
        if useRelu:
            layer = tf.nn.relu(layer)
        if useDropout:
            # Uses the keep_prob placeholder (1.0 unless fed), not FLAGS.keep_prob.
            layer = tf.nn.dropout(layer, keep_prob=keep_prob)
        return layer

    # Variables must be created in the same order as in training so the
    # auto-generated names match the checkpoint.
    layer_conv1 = create_convolutional_layer(x, FLAGS.num_channel, filter_size_conv1,
                                             num_filters_conv1, useBatchNorm=True, usePooling=True)
    layer_conv2 = create_convolutional_layer(layer_conv1, num_filters_conv1, filter_size_conv2,
                                             num_filters_conv2, useBatchNorm=True, usePooling=True)
    layer_conv3 = create_convolutional_layer(layer_conv2, num_filters_conv2, filter_size_conv3,
                                             num_filters_conv3, useBatchNorm=True, usePooling=True)
    layer_conv4 = create_convolutional_layer(layer_conv3, num_filters_conv3, filter_size_conv4,
                                             num_filters_conv4, useBatchNorm=True, usePooling=True)
    layer_conv5 = create_convolutional_layer(layer_conv4, num_filters_conv4, filter_size_conv5,
                                             num_filters_conv5, useBatchNorm=True, usePooling=True)
    layer_flat = create_flatten_layer(layer_conv5)
    layer_fc1 = create_fc_layer(layer_flat, layer_flat.get_shape()[1:4].num_elements(),
                                fc_layer_size, useRelu=True, useDropout=False)
    layer_fc2 = create_fc_layer(layer_fc1, fc_layer_size, fc_layer_size2,
                                useRelu=True, useDropout=True)
    layer_fc3 = create_fc_layer(layer_fc2, fc_layer_size2, FLAGS.num_classes, useRelu=False)

    y_pred = tf.nn.softmax(layer_fc3, name='y_pred', axis=1)
    y_pred_cls = tf.argmax(y_pred, axis=1)
    correct_prediction = tf.equal(y_pred_cls, y_true_cls)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    saver = tf.train.Saver()

    test_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.test_file,
                                          num_classes=FLAGS.num_classes,
                                          output_size=[FLAGS.img_size, FLAGS.img_size])
    # ``x`` has a fixed batch dimension, so only full batches can be fed. The
    # trailing len(labels) % batch_size images are not evaluated.
    test_batches_per_epoch = len(test_preprocessor.labels) // FLAGS.batch_size
    conf_mat = tf.confusion_matrix(y_true_cls, y_pred_cls, FLAGS.num_classes)

    if test_batches_per_epoch == 0:
        # Guard: the accuracy average below would otherwise divide by zero.
        print("{} No complete test batch of size {}; nothing to evaluate.".format(
            datetime.now(), FLAGS.batch_size))
        return

    # Where the per-image predicted class indices are written (one per line).
    predictions_path = 'test_predictions.txt'

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, checkpoint_dir)

        # Start Testing
        test_acc = 0.
        test_count = 0
        cm_total = None
        batch_predictions = []
        for _ in range(test_batches_per_epoch):
            batch_tx, batch_ty = test_preprocessor.next_batch(FLAGS.batch_size)
            # Fetch the predicted classes in the same run as accuracy and the
            # confusion matrix so all three describe the same forward pass.
            acc, conf_m, preds = sess.run([accuracy, conf_mat, y_pred_cls],
                                          feed_dict={x: batch_tx, y_true: batch_ty})
            if cm_total is None:
                cm_total = conf_m
            else:
                cm_total += conf_m
            test_acc += acc
            test_count += 1
            batch_predictions.append(preds)
        test_acc /= test_count
        print("{} Testing Accuracy = {:.2%}".format(datetime.now(), test_acc))
        test_preprocessor.reset_pointer()
        print(cm_total)

        # NOTE(review): presumably in the same order as the entries of
        # FLAGS.test_file, provided BatchPreprocessor does not shuffle — confirm.
        all_predictions = np.concatenate(batch_predictions)
        np.savetxt(predictions_path, all_predictions, fmt='%d')
        print("Saved {} predictions to {}".format(len(all_predictions), predictions_path))
这段代码用于测试：如您所见，我先恢复了训练和验证期间保存的检查点，然后用其中最好的检查点来预测测试数据。
`batch_tx` 是我的测试数据，`batch_ty` 是我的测试标签。
有谁知道我该怎么做?
提前致谢
答案 0（得分：1）
好的，根据上面的讨论，添加下面这一行：
sess.run(y_pred_cls,{x:batch_tx})
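把它放在混淆矩阵累加之后，就能得到每个批次的预测标签。更稳妥的做法是在计算 accuracy 的同一次 `sess.run` 中一并取出 `y_pred_cls`。由于网络在测试时仍启用了 dropout，单独再运行一次得到的预测会与计算准确率时的预测不同。把所有批次的预测拼接成一个 NumPy 数组，填入下面代码中的 `predictions`。如果测试代码在单线程中运行，并且不会打乱测试批次的顺序，那么预测标签的顺序就与图像在输入文件中的顺序一致。假设输入文件是 .bin 文件（每条记录先是 1 字节标签，然后是全部像素数据），就可以用 PIL 从中提取图像：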
import os

import numpy as np

# Your image dimensions here.
width = 80
height = 80
channels = 3
# Size of the label field at the start of each record (the original snippet
# assumed 1 byte).
labelSize = 1
pixelSize = width * height * channels
recordSize = labelSize + pixelSize


def read_records(input_filename, img_width=width, img_height=height,
                 num_channels=channels, label_size=labelSize):
    """Decode a binary file of fixed-size ``label + pixels`` records.

    Each record is ``label_size`` label bytes followed by the pixel bytes stored
    channel by channel (all of channel 0 row by row, then channel 1, ...).

    Args:
        input_filename: Path of the binary file.
        img_width, img_height, num_channels: Image geometry of every record.
        label_size: Number of leading label bytes per record; the first of
            them is returned as the label.

    Returns:
        ``(labels, images)``: ``labels`` is a 1-D uint8 array with one entry
        per record. ``images`` is a uint8 array of shape
        ``(num_records, img_height, img_width, num_channels)``.

    Raises:
        ValueError: if the file size is not a whole number of records.
    """
    record_size = label_size + img_width * img_height * num_channels
    with open(input_filename, "rb") as f:
        raw = np.fromfile(f, dtype=np.uint8)
    if raw.size % record_size:
        raise ValueError("{}: {} bytes is not a multiple of the record size {}".format(
            input_filename, raw.size, record_size))
    # Integer division: reshape() needs an int (``/`` yields a float in Python 3).
    records = raw.reshape(raw.size // record_size, record_size)
    labels = records[:, 0]
    # Stored channel-major as (C, H, W); move channels last -> (H, W, C).
    pixels = records[:, label_size:].reshape(-1, num_channels, img_height, img_width)
    images = pixels.transpose(0, 2, 3, 1)
    return labels, images


def save_classified_images(input_filename, predictions, label_names, output_dir="classified",
                           img_width=width, img_height=height, num_channels=channels,
                           label_size=labelSize):
    """Save each image into a sub-folder named after its predicted class.

    Record ``idx`` is written to
    ``<output_dir>/<predicted>/<idx>_pred<predicted>_actual<true>.png``.
    Misclassified images are therefore easy to spot inside each class folder.

    Args:
        input_filename: Binary record file (layout as in ``read_records``).
        predictions: Predicted class indices, one per record and in file order
            (e.g. the concatenated ``y_pred_cls`` outputs of the test loop).
            It may be shorter than the file; records without a prediction
            (such as a final partial batch that was never evaluated) are skipped.
        label_names: Sequence mapping a class index to its name.
        output_dir: Root folder; one sub-folder per predicted class is created.
        img_width, img_height, num_channels, label_size: Record layout.

    Returns:
        List of the written file paths, in record order.

    Raises:
        ValueError: if there are more predictions than records, or the file
            size is not a whole number of records.
    """
    # Imported lazily so read_records() works without Pillow installed.
    from PIL import Image

    true_labels, images = read_records(input_filename, img_width, img_height,
                                       num_channels, label_size)
    predictions = np.asarray(predictions).ravel()
    if len(predictions) > len(images):
        raise ValueError("got {} predictions but only {} records in {}".format(
            len(predictions), len(images), input_filename))

    saved_paths = []
    for idx, (pred, actual, image) in enumerate(zip(predictions, true_labels, images)):
        pred_label = label_names[int(pred)]
        true_label = label_names[int(actual)]
        class_dir = os.path.join(output_dir, str(pred_label))
        os.makedirs(class_dir, exist_ok=True)
        filename = "{}_pred{}_actual{}.png".format(idx, pred_label, true_label)
        if num_channels == 1:
            image = image[:, :, 0]  # Pillow expects (H, W) for grayscale.
        path = os.path.join(class_dir, filename)
        Image.fromarray(np.ascontiguousarray(image)).save(path, "PNG")
        saved_paths.append(path)
    return saved_paths


# Example (fill in your own label names and the predictions collected in the
# test loop, in the same order as the records in the file):
#   save_classified_images("test_batch.bin", all_predictions, ['cat', 'horse', 'dog'])
请注意，在填入实际的标签名称和预测结果之前，上述代码无法运行。此外，如果输入文件是 .csv 格式，则需要相应修改读取数据的部分。