我写了一个函数来绘制图像的一些类预测。我似乎无法弄清楚如何正确可视化所有内容。我有两个具体的问题:1)我无法获得正确的长宽比(底部行),以及2)我无法获得刻度线以使底部行旋转全部绘制而不是仅绘制最后一个图(但长宽比问题似乎也是如此)。第一行-图像本身-似乎绘制得很好。
def plot_classifier_predictions(classifier: Sequential, model_name: str, compression_factor: float, x: np.ndarray, y_true: np.ndarray,
examples=5, random=True, save=True, plot=True):
# Make predictions
y_pred = classifier.predict(x=x)
# Set indices
indices = rnd.sample(range(len(x)), examples) if random else [i for i in range(examples)]
# Get image dimension
image_dim = x.shape[1]
# Plot parameters
plot_count = examples * 2
row_count = 2
col_count = int(ceil(plot_count / row_count))
# Initialize axes
fig, subplot_axes = plt.subplots(row_count,
col_count,
squeeze=True,
figsize=(16, 10),
constrained_layout=True)
# Set colors
colors = sns.color_palette('pastel', n_colors=len(dat.CLASSES))
# Fill axes
for i in range(plot_count):
row = i // col_count
col = i % col_count
original_image = x[indices[col]]
ax = subplot_axes[row][col]
# First row: show original images
if row == 0:
ax.set_title("Image")
ax.imshow(original_image)
ax.axis('off')
# Second row: show predictions
else:
ax.set_title("Predictions")
ax.bar(x=range(len(dat.CLASSES)), height=y_pred[indices[col]], color=colors)
ax.set_xticks(ticks=range(len(dat.CLASSES)))
ax.set_xticklabels(dat.CLASSES)
ax.set_aspect(2)
plt.xticks(rotation=45)
plt.ylim(0,1.0)
# General make-up
plt.tight_layout()
# Title
plt.suptitle(
"Predictions (image dim {} - compression {})".format(image_dim, compression_factor),
fontweight='bold')
这是当前结果: