我是 PyTorch 新手,想为 MNIST 数据集自定义一个 LSTM 模型。下面是我为 MNIST 数据集编写的 2 层 LSTM 模型:第一个类是自定义的 LSTM 单元(Cust_LSTMCell),第二个类是 LSTM 模型。
class Cust_LSTMCell(nn.Module):
    """Stack of LSTM layers advanced by a single time step.

    Every layer owns two linear maps, input-to-hidden and hidden-to-hidden.
    Each map emits the four gate pre-activations concatenated along the
    feature dimension in the order: input, forget, candidate, output.
    """

    def __init__(self, input_size, hidden_size, nlayers, dropout):
        """Create the per-layer gate projections.

        Args:
            input_size: Feature size of the input fed to the first layer.
            hidden_size: Feature size of every layer's hidden and cell state.
            nlayers: Number of stacked layers.
            dropout: Dropout probability applied to a layer's output before
                it is passed on as the next layer's input.
        """
        super(Cust_LSTMCell, self).__init__()
        self.nlayers = nlayers
        self.dropout = nn.Dropout(p=dropout)
        # Build the two projections of a layer back to back so that the order
        # in which parameters are randomly initialised stays ih0, hh0, ih1, ...
        ih, hh = [], []
        for layer in range(nlayers):
            in_features = input_size if layer == 0 else hidden_size
            ih.append(nn.Linear(in_features, 4 * hidden_size))
            hh.append(nn.Linear(hidden_size, 4 * hidden_size))
        self.w_ih = nn.ModuleList(ih)
        self.w_hh = nn.ModuleList(hh)
        self.sig = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, input, hidden):
        """Run one time step through every layer.

        Args:
            input: Tensor expected to have shape (batch_size, input_size).
            hidden: Pair ``(h, c)``; ``h[k]`` and ``c[k]`` are the previous
                hidden and cell state of layer ``k``.

        Returns:
            Pair ``(hy, cy)`` holding the new hidden and cell states of all
            layers, stacked along a new leading layer dimension.
        """
        hy, cy = [], []
        # Keep the caller's argument untouched; this local carries whatever
        # is fed into the current layer.
        layer_input = input
        top_layer = self.nlayers - 1
        for layer in range(self.nlayers):
            hx, cx = hidden[0][layer], hidden[1][layer]
            gates = self.w_ih[layer](layer_input) + self.w_hh[layer](hx)
            # Split along the feature dimension into the four gates.
            i_gate, f_gate, c_gate, o_gate = gates.chunk(4, 1)
            i_gate = self.sig(i_gate)
            f_gate = self.sig(f_gate)
            c_gate = self.tanh(c_gate)
            o_gate = self.sig(o_gate)
            ncx = (f_gate * cx) + (i_gate * c_gate)
            nhx = o_gate * self.tanh(ncx)
            cy.append(ncx)
            hy.append(nhx)
            # Dropout only sits between layers: the returned states are the
            # un-dropped values, so applying it after the top layer would
            # produce a tensor nobody reads.
            if layer != top_layer:
                layer_input = self.dropout(nhx)
        return torch.stack(hy, 0), torch.stack(cy, 0)
class LSTM(nn.Module):
    """Sequence classifier: a stacked LSTM cell unrolled over time plus a linear head.

    The final hidden state of the top layer is passed through a fully
    connected layer to produce one score per class.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        """Build the recurrent stack and the classification head.

        Args:
            input_size: Feature size of each time step of the input.
            hidden_size: Feature size of every layer's hidden and cell state.
            num_layers: Number of stacked LSTM layers.
            num_classes: Number of output scores.
        """
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Inter-layer dropout is disabled (probability 0).
        self.lstm = Cust_LSTMCell(input_size, hidden_size, num_layers, 0)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: Tensor of shape (sequence_length, batch_size, input_size).
                The first dimension is iterated as time and the second is
                used as the batch size, so a batch-first tensor must be
                transposed by the caller before it is passed in.

        Returns:
            Tensor of shape (batch_size, num_classes) with the class scores.
        """
        # Zero initial states, created from x so they share its device and
        # dtype. Use the module's own layer count rather than a global name.
        hx = x.new_zeros(self.num_layers, x.size(1), self.hidden_size)
        cx = x.new_zeros(self.num_layers, x.size(1), self.hidden_size)
        for x_t in x:  # one slice of shape (batch_size, input_size) per time step
            # hx / cx come back with shape (num_layers, batch_size, hidden_size)
            hx, cx = self.lstm(x_t, (hx, cx))
        # Classify from the top layer's last hidden state, whatever the depth.
        return self.fc(hx[-1])
训练损失(始终停留在 2.30 左右,约等于 ln 10,也就是 10 个类别随机猜测时的交叉熵,说明模型没有在学习):
Epoch [1/2], Step [100/600], Loss: 2.3095
Epoch [1/2], Step [200/600], Loss: 2.2956
Epoch [1/2], Step [300/600], Loss: 2.2852
Epoch [1/2], Step [400/600], Loss: 2.3077
Epoch [1/2], Step [500/600], Loss: 2.3027
Epoch [1/2], Step [600/600], Loss: 2.3133
Epoch [2/2], Step [100/600], Loss: 2.3056
Epoch [2/2], Step [200/600], Loss: 2.2853
Epoch [2/2], Step [300/600], Loss: 2.3103
Epoch [2/2], Step [400/600], Loss: 2.3044
Epoch [2/2], Step [500/600], Loss: 2.3034
Epoch [2/2], Step [600/600], Loss: 2.3049
我已经多次检查并尝试修改代码,但仍然找不到问题所在。有人能帮忙看看吗? ;)