我训练了一个精简版的 MobileNetV1，现在想在评估阶段可视化所有层学到的滤波器（卷积核）。我使用的 MobileNet 实现来自：
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
我已经尝试过 TensorFlow 调试器（tfdbg）以及其他一些只适用于简单 CNN 的方法，但没有找到适用于我当前情况的信息。
这是我的评估代码:
def _add_filter_summaries(variables, max_outputs=16, scale=8):
    """Add a TensorBoard image summary for every rank-4 convolution kernel.

    This shows the raw learned weights of each layer. It is NOT the
    input-space "what does this filter respond to" visualization from the
    Keras blog post; that needs gradient ascent on an input image.

    Args:
      variables: iterable of tf.Variable. Only rank-4 variables are used.
        Assumes slim's [height, width, in_channels, out_channels] kernel
        layout (depthwise kernels: [h, w, channels, multiplier]) -- TODO
        confirm for any custom layers.
      max_outputs: maximum number of 2-D filter slices shown per layer.
      scale: nearest-neighbour upscaling factor, so that 3x3 / 1x1 kernels
        are actually visible in TensorBoard.

    Returns:
      List of the created summary tensors (they are also added to the
      default summary collection, so `tf.summary.merge_all()` picks them up).
    """
    summaries = []
    for var in variables:
        shape = var.get_shape()
        if shape.ndims != 4:
            continue
        height, width = shape.as_list()[:2]
        # [h, w, in, out] -> [in, out, h, w] -> [in * out, h, w, 1]: every 2-D
        # kernel slice becomes one single-channel image.
        kernel = tf.transpose(var, [2, 3, 0, 1])
        kernel = tf.reshape(kernel, [-1, height, width, 1])
        # Slice BEFORE resizing: pointwise layers have in*out ~ 1e6 slices and
        # upscaling all of them would allocate hundreds of MB.
        kernel = kernel[:max_outputs]
        kernel = tf.image.resize_nearest_neighbor(
            kernel, [height * scale, width * scale])
        # tf.summary.image rescales float images to [0, 255] one image at a time.
        summaries.append(tf.summary.image(
            'filters/' + var.op.name.replace('/', '_'), kernel,
            max_outputs=max_outputs))
    return summaries


with tf.Graph().as_default() as graph:
    tf.logging.set_verbosity(tf.logging.INFO)

    # Validation input pipeline.
    dataset = get_split('validation', dataset_dir)
    images, raw_images, labels = load_batch(dataset, batch_size=batch_size, is_training=False)
    # Floor division: with true division (Python 3) this would be a float and
    # `step % num_batches_per_epoch == 0` below would almost never fire.
    num_batches_per_epoch = dataset.num_samples // batch_size
    num_steps_per_epoch = num_batches_per_epoch

    with slim.arg_scope(mobilenet_arg_scope()):
        logits, end_points = mobilenet(images, num_classes=dataset.num_classes, is_training=False)

    variables_to_restore = slim.get_variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    def restore_fn(sess):
        """Restore the trained weights from `checkpoint_file` into `sess`."""
        return saver.restore(sess, checkpoint_file)

    probabilities = end_points['Predictions']
    predictions = tf.argmax(probabilities, 1)
    accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, labels)
    metrics_op = tf.group(accuracy_update)

    global_step = get_or_create_global_step()
    global_step_op = tf.assign(global_step, global_step + 1)

    def eval_step(sess, metrics_op, global_step, summary_op=None):
        """Run one evaluation step: update the streaming accuracy, bump the step.

        If `summary_op` is given, it is fetched in the SAME `sess.run` as the
        metric update. That way the summaries (e.g. the confusion matrix)
        describe the batch that was actually counted, and no extra batch is
        dequeued from the input pipeline. The summaries are written through
        the supervisor `sv`, which is created below before the first call.

        `global_step` is accepted for call-site compatibility but is unused;
        the step counter is advanced via `global_step_op`.

        Returns:
          The streaming accuracy value read in this step.
        """
        fetches = [metrics_op, global_step_op, accuracy]
        if summary_op is not None:
            fetches.append(summary_op)
        start_time = time.time()
        results = sess.run(fetches)
        time_elapsed = time.time() - start_time
        global_step_count, accuracy_value = results[1], results[2]
        logging.info('Global Step %s: Streaming Accuracy: %.4f (%.2f sec/step)', global_step_count, accuracy_value,
                     time_elapsed)
        if summary_op is not None:
            sv.summary_computed(sess, results[3])
        return accuracy_value

    tf.summary.scalar('Validation_Accuracy', accuracy)
    # Use the dataset's class count instead of a hard-coded 3 so the matrix
    # matches the model's output size.
    matrix = tf.confusion_matrix(labels, predictions, num_classes=dataset.num_classes)
    image_tensor = draw_confusion_matrix(matrix)
    image_summary = tf.summary.image('confusion_matrix', image_tensor)
    # Learned-filter images for every conv / depthwise / pointwise layer.
    filter_summaries = _add_filter_summaries(slim.get_model_variables())
    my_summary_op = tf.summary.merge_all()

    sv = tf.train.Supervisor(logdir=log_eval, summary_op=None, init_fn=restore_fn)
    with sv.managed_session() as sess:
        for step in xrange(int(num_batches_per_epoch * num_epochs)):
            if step % num_batches_per_epoch == 0:
                logging.info('Epoch: %s/%s', step // num_batches_per_epoch + 1, num_epochs)
                logging.info('Current Streaming Accuracy: %.4f', sess.run(accuracy))
            # Every 10th step also compute and write the summaries, in the same
            # run as the metric update (see eval_step).
            if step % 10 == 0:
                eval_step(sess, metrics_op=metrics_op, global_step=sv.global_step,
                          summary_op=my_summary_op)
            else:
                eval_step(sess, metrics_op=metrics_op, global_step=sv.global_step)
        logging.info('Final eval Accuracy: %.4f', sess.run(accuracy))
我希望能够查看并保存这些可视化结果，就像在 Keras 中可以轻松做到的那样：
https://blog.keras.io/how-convolutional-neural-networks-see-the-world.html