对齐上下轴的刻度坐标

时间:2019-02-08 17:48:13

标签: python matplotlib alignment axes

我想构建一个图形,在其底部x轴上具有一组刻度线,在其顶部x轴上具有与底部刻度线对齐的另一组刻度线。特别是在我的情况下,这些是批次和时期。对于底部的每个n批处理点(不一定是刻度),我希望在顶部有一个纪元刻度。考虑以下示例:

import numpy as np
import matplotlib.pyplot as plt

batches = np.arange(1,101)
epoch_ends = batches[[(i*10)-1 for i in range(1,11)]]
accuracy = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : x/len(batches))
loss = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : 1 - (x/len(batches)))

fig, ax1 = plt.subplots( nrows=1, ncols=1 )
ax2 = ax1.twinx()
ax3 = ax1.twiny()

ax1.set_xlabel('batches')
ax1.set_xticks(np.arange(1, len(batches)+1, 9))
ax1.set_ylabel('accuracy')
ax1.grid()

ax2.set_ylabel('loss')
ax2.set_yticklabels(np.linspace(3, 10, 9))

ax3.set_xlabel('epochs')
ax3.set_xticks(epoch_ends)
ax3.set_xticklabels(range(1, len(epoch_ends)+1))

acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')
loss_plt = ax2.plot(batches, loss, color='blue', label='loss')

lns = acc_plt+loss_plt
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc=2)

plt.show()

batchesepoch_ends分别看起来像这样

[  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
  91  92  93  94  95  96  97  98  99 100]
[ 10  20  30  40  50  60  70  80  90 100]

因此,我希望纪元刻度1与批处理x-coorrdiante 10、2与20等对齐。

但是,正如您在图片中看到的那样,它们没有对齐。

enter image description here

要使此功能生效,我需要更改什么代码?

2 个答案:

答案 0 :(得分:2)

这是对齐它们的一种方法。想法如下:

  • 首先使用(ax1)在较低的x轴上绘制数据
  • 然后使用ax3.set_xlim(ax1.get_xlim())将x轴的上限设置为与x轴的下限相同。
  • 然后在对应于较低x轴值(10、20、30,...,90、100)的位置设置较高x轴的刻度
  • 最后,使用ax3.set_xticklabels()重命名刻度标签。

代码是这里:我用注释#替换了代码中已有的部分。

# imports and define data and compute accuracy and loss here

# Initiate figure and axis objects here

ax1.set_xlabel('batches')
ax1.set_xticks(np.arange(1, len(batches)+1, 9))
ax1.set_ylabel('accuracy')
ax1.grid()

acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')
loss_plt = ax2.plot(batches, loss, color='blue', label='loss')

ax2.set_ylabel('loss')
ax2.set_yticklabels(np.linspace(3, 10, 9))

new_tick_locations = np.arange(1, 11)*10
new_tick_labels = range(1, 11)

ax3.set_xlabel('epochs')
ax3.set_xlim(ax1.get_xlim())
ax3.set_xticks(new_tick_locations)
ax3.set_xticklabels(new_tick_labels)

# Set legends and show the plot

enter image description here

答案 1 :(得分:1)

这使您无法控制刻度线(您没有说是否需要),但是会像您想要的那样将两个x轴对齐: (注释新行)

import numpy as np
import matplotlib.pyplot as plt

batches = np.arange(1,101)
epoch_ends = batches[[(i*10)-1 for i in range(1,11)]]
accuracy = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : x/len(batches))
loss = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : 1 - (x/len(batches)))

fig, ax1 = plt.subplots( nrows=1, ncols=1 )
ax2 = ax1.twinx()
ax3 = ax1.twiny()

ax1.set_xlabel('batches')
#ax1.set_xticks(np.arange(1, len(batches)+1, 9))
ax1.set_xlim(0, 100)                               # Set xlim
ax1.set_ylabel('accuracy')
ax1.grid()

ax2.set_ylabel('loss')
ax2.set_yticklabels(np.linspace(3, 10, 9))

ax3.set_xlabel('epochs')
#ax3.set_xticks(epoch_ends)
#ax3.set_xticklabels(range(1, len(epoch_ends)+1))
ax3.set_xlim(0, 10)                                # Set xlim

acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')
loss_plt = ax2.plot(batches, loss, color='blue', label='loss')

lns = acc_plt+loss_plt
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc=2)

plt.show()