在子图中组合不同的展示和栏

时间:2019-06-02 07:58:33

标签: python python-3.x matplotlib

我写了一个函数来绘制图像的一些类预测。我似乎无法弄清楚如何正确可视化所有内容。我有两个具体的问题:1)我无法获得正确的长宽比(底部行),以及2)我无法获得刻度线以使底部行旋转全部绘制而不是仅绘制最后一个图(但长宽比问题似乎也是如此)。第一行-图像本身-似乎绘制得很好。

def plot_classifier_predictions(classifier: Sequential, model_name: str, compression_factor: float, x: np.ndarray, y_true: np.ndarray,
                             examples=5, random=True, save=True, plot=True):


    # Make predictions
    y_pred = classifier.predict(x=x)

    # Set indices
    indices = rnd.sample(range(len(x)), examples) if random else [i for i in range(examples)]

    # Get image dimension
    image_dim = x.shape[1]

    # Plot parameters
    plot_count = examples * 2
    row_count = 2
    col_count = int(ceil(plot_count / row_count))

    # Initialize axes
    fig, subplot_axes = plt.subplots(row_count,
                                     col_count,
                                     squeeze=True,
                                     figsize=(16, 10),
                                     constrained_layout=True)

    # Set colors
    colors = sns.color_palette('pastel', n_colors=len(dat.CLASSES))

    # Fill axes
    for i in range(plot_count):

        row = i // col_count
        col = i % col_count

        original_image = x[indices[col]]

        ax = subplot_axes[row][col]

        # First row: show original images
        if row == 0:
            ax.set_title("Image")
            ax.imshow(original_image)
            ax.axis('off')

        # Second row: show predictions
        else:
            ax.set_title("Predictions")

            ax.bar(x=range(len(dat.CLASSES)), height=y_pred[indices[col]], color=colors)
            ax.set_xticks(ticks=range(len(dat.CLASSES)))
            ax.set_xticklabels(dat.CLASSES)
            ax.set_aspect(2)
            plt.xticks(rotation=45)
            plt.ylim(0,1.0)



    # General make-up
    plt.tight_layout()

    # Title
    plt.suptitle(
        "Predictions (image dim {} - compression {})".format(image_dim, compression_factor),
        fontweight='bold')

这是当前结果:

enter image description here

0 个答案:

没有答案