pyplot创造了许多不同的线条

时间:2018-04-14 01:50:44

标签: python matplotlib

我试图绘制一些函数,这些函数由于某种原因看起来像是由较小的模糊线组成而不是单行。

This is the image

这两个图上的四个函数只是我在for循环中创建的数字列表。

for i in range(iters):

    #zero the gradients
    amsgrad.zero_grad()
    adam.zero_grad()

    #Perform an optimization step
    adam.step(closure_adam)
    amsgrad.step(closure_amsgrad)

    #Clamp the variables between -1 and 1
    x_var_adam.data = x_var_adam.data.clamp(-1,1)
    x_var_amsgrad.data = x_var_amsgrad.data.clamp(-1,1)

    #Calculate the regret
    adam_regret = regret(total_loss_adam,total_min_loss_adam,t)
    ams_regret = regret(total_loss_amsgrad,total_min_loss_amsgrad,t)

    #Store regret
    regret_adam_hist.append(adam_regret)
    regret_amsgrad_hist.append(ams_regret)

    #Store the x_t values
    x_var_adam_hist.append(x_var_adam.data[0])
    x_var_amsgrad_hist.append(x_var_amsgrad.data[0])

    #Adjust learning rate by dividing by sqrt(t)
    adjust_learning_rate(adam,my_lr,t)
    adjust_learning_rate(amsgrad,my_lr,t)
    t+=1

# this is the second photo
x = np.arange(0,iters)
plt.clf()
plt.xlabel("Iterations")
plt.ylabel("$R_t/t$")
plt.plot(x,regret_adam_hist, label="adam", c='b', ls='--')
plt.plot(x,regret_amsgrad_hist,label="amsgrad",c='g')
plt.axis([0, iters, 0, 0.5])
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
plt.legend(loc='best')
plt.rcParams["figure.figsize"] = (15,10)
plt.show()

plt.clf()

# this is the first photo
x = np.arange(0,iters+1)
plt.plot(x, x_var_adam_hist, label="adam", c='b', ls='--')
plt.xlabel("Iterations")
plt.ylabel("$x_t$")
plt.plot(x, x_var_amsgrad_hist, label="amsgrad", c='g', ls='dotted')
plt.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
plt.rcParams["figure.figsize"] = (15,10)
plt.legend(loc='best')
plt.show()

我不确定为什么pyplot会给我这些看起来像计算多行的图表。这是上面绘制的遗憾功能:

Regret

这些都应该看起来像线条流畅。我检查过以下内容:

for n in regret_amsgrad_hist:
assert (type(n) == float)
assert (type(regret_amsgrad_hist) == list)
assert (len(regret_amsgrad_hist) == 60000)

我不知道是什么原因导致它变得越来越严重"至于看起来像一个小于或等于功能或东西。有没有人有任何想法为什么我的图表不是两个图中的简单2平滑函数?

我猜这些函数唯一不寻常的是每101次迭代,函数在相反的方向上出现。以下是第二张图片(第二张图片)的一些示例值(平均后悔enter image description here):

https://pastebin.com/UrYGkFHT

编辑:我已经意识到我的功能没有任何问题,但它似乎将振荡压缩成一条粗线。有什么方法可以避免这种情况吗?例如:

此: iters=60000

使用较少的示例看起来像这样:

iters=3000

2 个答案:

答案 0 :(得分:0)

由于您的绘图样本高度振荡,这是使用实线时pyplot的预期行为。

如果由于振荡而想要表示没有这条大实线的振荡样品,最好的选择是只使用符号而不使用线。

答案 1 :(得分:0)

之所以摆动如此之多,原因在于这个功能实际上是在振荡很多。在我试图重现的图表中,图表似乎没有振荡。这是因为他们在x轴上使用了 plt.semilogx ,它将振荡缩小为实线。

解决方案是在x轴上使用plt.semilogx。