SGDClassifier: save the loss of each iteration to an array

Time: 2019-01-27 13:29:24

Tags: scikit-learn gradient-descent

When I train an SGDClassifier in scikit-learn, I can print out the loss value of each iteration (by setting the verbosity). How can I store those values in an array?

1 Answer:

Answer 0 (score: 0)

This adapts the answer from this post.

import sys
import numpy as np
from io import StringIO
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDClassifier
from tensorflow.keras.datasets import mnist

# Load MNIST and flatten each 28x28 image into a 784-dimensional vector
(x_tr, y_tr), (x_te, y_te) = mnist.load_data()
x_tr, x_te = x_tr.reshape(-1, 784), x_te.reshape(-1, 784)

Intercept the printed output of SGDClassifier

old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()

Set the model up to print its progress by setting verbose to 1.

clf = SGDClassifier(verbose=1)
clf.fit(x_tr, y_tr)
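
As a side note (not part of the original answer), the number of epochs that appear in the captured log can be pinned down explicitly: with tol=None the tolerance-based stopping criterion is disabled and scikit-learn runs the full max_iter epochs.

clf = SGDClassifier(verbose=1, max_iter=20, tol=None)  # optional: fixed number of logged epochs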

Restore stdout and retrieve the captured verbose output of SGDClassifier

sys.stdout = old_stdout
loss_history = mystdout.getvalue()

Create a list to store the loss values

loss_list = []

Append the printed loss values, which are stored in loss_history

for line in loss_history.split('\n'):
    # Skip lines that do not report a loss value
    if "loss: " not in line:
        continue
    loss_list.append(float(line.split("loss: ")[-1]))
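
For reference, the parsing above relies on the "loss: " substring in each verbose line. A typical line looks roughly like the sample below; the exact wording (e.g. "Avg. loss:") can differ between scikit-learn versions, so this is only an illustrative sketch.

sample = "Norm: 21.05, NNZs: 7421, Bias: -1.03, T: 60000, Avg. loss: 0.117354"  # assumed format
float(sample.split("loss: ")[-1])  # -> 0.117354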

Show the loss plot

plt.figure()
plt.plot(np.arange(len(loss_list)), loss_list)
plt.xlabel("Time in epochs"); plt.ylabel("Loss")
plt.show()

To save the loss values into a NumPy array:

loss_list = np.array(loss_list)
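
As an alternative sketch (not from the original answer), you can drive the epochs yourself with partial_fit and compute the training loss directly, which avoids capturing stdout. This assumes loss='log_loss' (named 'log' in older scikit-learn releases) so that predict_proba is available; the resulting values will not match the verbose-reported average loss exactly, since they are evaluated after each epoch rather than averaged during it.

from sklearn.metrics import log_loss

clf = SGDClassifier(loss='log_loss')  # assumption: log loss, so predict_proba exists
classes = np.unique(y_tr)
manual_loss_list = []
for epoch in range(20):
    # One pass over the training data, then evaluate the log loss on it
    clf.partial_fit(x_tr, y_tr, classes=classes)
    manual_loss_list.append(log_loss(y_tr, clf.predict_proba(x_tr), labels=classes))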