我正在训练一个关于词嵌入的简单神经网络。我试图每 1000 个时期保存一个嵌入图:
if epo%1000 == 0:
print(f'Epoch {epo}')
#save plots
W = net.fc1.weight
W = W.detach().numpy()
svd = decomposition.TruncatedSVD(n_components=2)
W1_dec = svd.fit_transform(W.T)
x = W1_dec[:,0]
y = W1_dec[:,1]
plot = sns.scatterplot(x, y)
for i in range(0,W1_dec.shape[0]):
plot.text(x[i], y[i]+2e-2, list(vocabulary.keys())[i], horizontalalignment='center', size='small', color='black', weight='semibold')
plt.savefig(f'images/epoch{epo}.png')
最初的训练和情节会很好,但过一段时间它们会显着放缓。我以前在每次使用整个计算图时都遇到过这个问题,但我在这里看不到哪里可以做同样的事情。
我尝试删除绘图部分(在“#save plots”下),一切正常并且不会变慢。
更新
看起来最后一行是减慢进程的那一行:
plt.savefig(f'images/epoch{epo}.png')
不知道怎么解决