我试图在一个图中绘制3条损耗曲线。我有几个问题。
我无法获得平滑的曲线,相反,它以一条直线连接点对点
如何更改轴的比例,以显示由于太小而消失的MSE损失?
epochs=list(range(5000,50001,5000))
print(epochs)
mae_loss=[0.500225365,
0.000221096,
0.000060971,
0.000060323,
0.000059905,
0.000059579,
0.000059274,
0.000058972,
0.000058697,
0.000058476]
mse_loss=[0.135419831,
0.018331185,
0.002481434,
0.000335913,
0.000045486,
0.000006180,
0.000000867,
0.000000147,
0.000000042,
0.000000042]
rmse_loss=[0.500225306,
0.000293739,
0.000126985,
0.000121944,
0.000119484,
0.000117791,
0.000116400,
0.000115198,
0.000114148,
0.000113228]
plt.plot(epochs, mae_loss, 'b', label='MAE')
plt.plot(epochs, mse_loss, 'r', label='MSE')
plt.plot(epochs, mse_loss, 'g', label='RMSE')
plt.legend()
plt.show()```
答案 0 :(得分:2)
您将需要一些插值方法才能获得平滑的样条曲线/曲线。这本身就是一个不同的问题。我将回答有关不同比例的问题。由于数据量级的差异很大,因此在这种情况下,最好的解决方案是使用semilogy
使用对数y刻度。附注:您在最后的绘图行中写的是mse_loss
而不是rmse_loss
。
plt.semilogy(epochs, mae_loss, 'b', label='MAE')
plt.semilogy(epochs, mse_loss, 'r', label='MSE')
plt.semilogy(epochs, rmse_loss, 'g', label='RMSE')
plt.legend()
plt.show()
答案 1 :(得分:1)
要平滑图:
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import make_interp_spline, BSpline
def create_spline_from(x, y, resolution):
new_x = np.linspace(x[0], x[-1], resolution)
y_spline = make_interp_spline(x, y, k=3)
new_y= y_spline(new_x)
return (new_x, new_y)
epochs=list(range(5000,50001,5000))
print(epochs)
mae_loss=[0.500225365,
0.000221096,
0.000060971,
0.000060323,
0.000059905,
0.000059579,
0.000059274,
0.000058972,
0.000058697,
0.000058476]
mse_loss=[0.135419831,
0.018331185,
0.002481434,
0.000335913,
0.000045486,
0.000006180,
0.000000867,
0.000000147,
0.000000042,
0.000000042]
rmse_loss=[0.500225306,
0.000293739,
0.000126985,
0.000121944,
0.000119484,
0.000117791,
0.000116400,
0.000115198,
0.000114148,
0.000113228]
x, y = create_spline_from(epochs, mae_loss, 50)
plt.plot(x, y, 'b', label='MAE')
x, y = create_spline_from(epochs, mse_loss, 50)
plt.plot(x, y, 'r', label='MSE')
x, y = create_spline_from(epochs, rmse_loss, 50)
plt.plot(x, y, 'g', label='RMSE')
plt.legend()
plt.show()