我试图简单地绘制几条曲线(你们中的一些人可能认出了caffe的lr_policy)
import numpy as np
import matplotlib.pyplot as plt
base_lr=0.0003
gamma=0.0001
epoch=10000
step=5*epoch
stepsize=step
power=0.5
max_iter=400000
iter=np.arange(0,max_iter)
fixed_policy=np.ones(max_iter)*base_lr
step_policy=base_lr*(gamma**(np.floor(iter/step)))
exp_policy=base_lr*(gamma**iter)
inv_policy=base_lr*((1+gamma*iter)**(-power))
poly_policy=base_lr*(1-iter/max_iter)**(power)
sigmoid_policy=base_lr*(1/(1+np.exp(-gamma * (iter -stepsize))))
print iter.shape
print fixed_policy.shape
print step_policy.shape
print exp_policy.shape
print inv_policy.shape
print poly_policy.shape
print sigmoid_policy.shape
iter = iter
plt.plot(iter, fixed_policy, label='fixed_policy')
plt.plot(iter, step_policy, label='step_policy')
plt.plot(iter, exp_policy, label='exp_policy')
plt.plot(iter, inv_policy, label='inv_policy')
plt.plot(iter, poly_policy, label='poly_policy')
plt.plot(iter, sigmoid_policy, label='sigmoid_policy')
plt.legend()
plt.yscale('log')
plt.show()
当我执行它时,它会回到我身边:
/usr/lib/python2.7/dist-packages/numpy/ma/core.py:3847: UserWarning: Warning: converting a masked element to nan.
warnings.warn("Warning: converting a masked element to nan.")
Exception in Tkinter callback
Traceback (most recent call last):
File "/usr/lib/python2.7/lib-tk/Tkinter.py", line 1489, in __call__
return self.func(*args)
File "/usr/lib/pymodules/python2.7/matplotlib/backends/backend_tkagg.py", line 276, in resize
self.show()
File "/usr/lib/pymodules/python2.7/matplotlib/backends/backend_tkagg.py", line 348, in draw
FigureCanvasAgg.draw(self)
File "/usr/lib/pymodules/python2.7/matplotlib/backends/backend_agg.py", line 451, in draw
self.figure.draw(self.renderer)
File "/usr/lib/pymodules/python2.7/matplotlib/artist.py", line 55, in draw_wrapper
draw(artist, renderer, *args, **kwargs)
File "/usr/lib/pymodules/python2.7/matplotlib/figure.py", line 1034, in draw
func(*args)
File "/usr/lib/pymodules/python2.7/matplotlib/artist.py", line 55, in draw_wrapper
draw(artist, renderer, *args, **kwargs)
File "/usr/lib/pymodules/python2.7/matplotlib/axes.py", line 2086, in draw
a.draw(renderer)
File "/usr/lib/pymodules/python2.7/matplotlib/artist.py", line 55, in draw_wrapper
draw(artist, renderer, *args, **kwargs)
File "/usr/lib/pymodules/python2.7/matplotlib/axis.py", line 1096, in draw
tick.draw(renderer)
File "/usr/lib/pymodules/python2.7/matplotlib/artist.py", line 55, in draw_wrapper
draw(artist, renderer, *args, **kwargs)
File "/usr/lib/pymodules/python2.7/matplotlib/axis.py", line 241, in draw
self.label1.draw(renderer)
File "/usr/lib/pymodules/python2.7/matplotlib/artist.py", line 55, in draw_wrapper
draw(artist, renderer, *args, **kwargs)
File "/usr/lib/pymodules/python2.7/matplotlib/text.py", line 598, in draw
ismath=ismath, mtext=self)
File "/usr/lib/pymodules/python2.7/matplotlib/backends/backend_agg.py", line 169, in draw_text
return self.draw_mathtext(gc, x, y, s, prop, angle)
File "/usr/lib/pymodules/python2.7/matplotlib/backends/backend_agg.py", line 159, in draw_mathtext
y = np.round(y - oy + yd)
File "/usr/lib/python2.7/dist-packages/numpy/core/fromnumeric.py", line 2629, in round_
File "/usr/lib/python2.7/dist-packages/numpy/ma/core.py", line 4855, in round
result._mask = self._mask
AttributeError: 'numpy.float64' object has no attribute '_mask'
我尝试使用.tolist()更改为float,但它仍然会返回相同的错误。
任何提示?
谢谢。