“我的数据”在这里代表一张大约3600个节点的图。而且我想借助Graph的最后一个节点的已知数据来预测下一个节点。
在这里,我对6个维度数据使用了一个LSTM Cell层。[x,y,u,v,p,type]。最后,我只想获得3D [u,v,p]输出。 其实效果很好。
但是当我看到相同图形的30,000次迭代后的lossHisto时。 就像我在上面附上一张照片一样。
我不知道,为什么我的损失没有如我们在0.025下所预期的那样减少。为何有时却显示出很大的波动而不是曲线。
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
plt.switch_backend('agg')
class Sequence(nn.Module):
def __init__(self):
super(Sequence, self).__init__()
self.lstmcell = nn.LSTMCell(5, 256)
self.linear = nn.Linear(256, 3)
def forward(self, x, path):
outputs = []
h_t = torch.zeros(1, 256)
c_t = torch.zeros(1, 256)
for i in range(x.shape[0]):
if trainData[path[i],5] == 0:
neighbours = graphDict[str(path[i])]
y = graphFlags[neighbours]
idx = np.where(y !=0)[0]
z = torch.Tensor(x[neighbours[idx],:])
u = torch.mean(z,0)
else:
u = (x[path[i],:]).clone()
input_t = u.view(1,5)
h_t, c_t = self.lstmcell(input_t, (h_t, c_t))
output = torch.tanh(self.linear(h_t))
outputs += [output]
graphFlags[path[i]] = 1
x[path[i], 0:3] = output
outputs = torch.cat((outputs), 0)
return outputs
for t in range(30000):
graphFlags = np.copy(trainData[:,5])
trueData_ = torch.Tensor(trueData[:, 2:5])
binary_repr = np.unpackbits(np.copy(trainData)[:,5].astype(np.uint8).reshape(-1,1), axis=1)[:,6:]
x = np.concatenate((np.copy(trainData)[:,2:5], binary_repr), axis=1)
x = torch.Tensor(x)
optimizer.zero_grad()
out = seq(x, path)
loss = criterion(out, trueData_)
counter+=1
loss.backward()
optimizer.step()
lossHisto.append(loss.item())
if counter == 100:
saveName = 'out'
SaveName = 'lossHisto'
torch.save(out, '/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+saveName+str(t))
BILD = torch.load('/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+saveName+str(t))
plt.figure(figsize=(20,10))
plt.scatter(trueData[:,0],trueData[:,1],c=torch.Tensor(BILD.data[:,2]),marker='.')
plt.colorbar()
plt.savefig("/mnt/fs2/home/bhavesh/Ergebnisse/test1/out%d.png"%t)
plt.clf()
torch.save(lossHisto, '/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+SaveName+str(t))
LOSS = torch.load('/mnt/fs2/home/bhavesh/Ergebnisse/test1/'+SaveName+str(t))
plt.plot(LOSS)
plt.grid()
plt.savefig("/mnt/fs2/home/bhavesh/Ergebnisse/test/lossHisto%d.png"%t)
counter = 0
print(loss.item())