为什么我的算法中的LSTM单元损耗图显示出很多上下波动?

时间:2018-12-25 09:21:02

标签: python pytorch

Hier I attached Image of my Loss Diagram. I am very new to stack overflow, if possible, please try to answer here.

“我的数据”在这里代表一张大约3600个节点的图。而且我想借助Graph的最后一个节点的已知数据来预测下一个节点。

在这里,我对6个维度数据使用了一个LSTM Cell层。[x,y,u,v,p,type]。最后,我只想获得3D [u,v,p]输出。 其实效果很好。

但是当我看到相同图形的30,000次迭代后的lossHisto时。 就像我在上面附上一张照片一样。

我不知道,为什么我的损失没有如我们在0.025下所预期的那样减少。为何有时却显示出很大的波动而不是曲线。

import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
plt.switch_backend('agg')

class Sequence(nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstmcell = nn.LSTMCell(5, 256)
        self.linear = nn.Linear(256, 3)

    def forward(self, x, path):

        outputs = []
        h_t = torch.zeros(1, 256)
        c_t = torch.zeros(1, 256)         

        for i in range(x.shape[0]):
            if  trainData[path[i],5] == 0:
                neighbours = graphDict[str(path[i])]
                y = graphFlags[neighbours]
                idx = np.where(y !=0)[0]
                z = torch.Tensor(x[neighbours[idx],:])
                u = torch.mean(z,0)
            else:              
                u = (x[path[i],:]).clone()
            input_t = u.view(1,5)
            h_t, c_t = self.lstmcell(input_t, (h_t, c_t))          
            output = torch.tanh(self.linear(h_t))
            outputs += [output]
            graphFlags[path[i]] = 1
            x[path[i], 0:3] = output     
        outputs = torch.cat((outputs), 0)
        return outputs

for t in range(30000):
    graphFlags = np.copy(trainData[:,5])
    trueData_ = torch.Tensor(trueData[:, 2:5])

    binary_repr = np.unpackbits(np.copy(trainData)[:,5].astype(np.uint8).reshape(-1,1), axis=1)[:,6:]   
    x = np.concatenate((np.copy(trainData)[:,2:5], binary_repr), axis=1)
    x = torch.Tensor(x)

    optimizer.zero_grad()
    out = seq(x, path)

    loss = criterion(out, trueData_)
    counter+=1
    loss.backward()
    optimizer.step()
    lossHisto.append(loss.item())

    if counter == 100:

        saveName = 'out'
        SaveName = 'lossHisto'
        torch.save(out, '/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+saveName+str(t))
        BILD = torch.load('/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+saveName+str(t))
        plt.figure(figsize=(20,10))
        plt.scatter(trueData[:,0],trueData[:,1],c=torch.Tensor(BILD.data[:,2]),marker='.')
        plt.colorbar()   
        plt.savefig("/mnt/fs2/home/bhavesh/Ergebnisse/test1/out%d.png"%t)   
        plt.clf()
        torch.save(lossHisto, '/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+SaveName+str(t))
        LOSS = torch.load('/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+SaveName+str(t))
        plt.plot(LOSS)
        plt.grid()
        plt.savefig("/mnt/fs2/home/bhavesh/Ergebnisse/test/lossHisto%d.png"%t)
        counter = 0

    print(loss.item()) 

0 个答案:

没有答案