我训练的LSTM的批次大小为128,而在测试期间,我的批次大小为1,为什么会出现此错误?我想在进行测试时初始化隐藏的大小吗?
这是我正在使用的代码,由于 batch_first = True <,我将隐藏状态 init_hidden 初始化为(层数,批处理大小,隐藏大小)。 / em>
class ImageLSTM(nn.Module):
def __init__(self, n_inputs:int=49,
n_outputs:int=4096,
n_hidden:int=256,
n_layers:int=1,
bidirectional:bool=False):
"""
Takes a 1D flatten images.
"""
super(ImageLSTM, self).__init__()
self.n_inputs = n_inputs
self.n_hidden = n_hidden
self.n_outputs = n_outputs
self.n_layers = n_layers
self.bidirectional = bidirectional
self.lstm = nn.LSTM( input_size=self.n_inputs,
hidden_size=self.n_hidden,
num_layers=self.n_layers,
dropout = 0.5 if self.n_layers>1 else 0,
bidirectional=self.bidirectional,
batch_first=True)
if (self.bidirectional):
self.FC = nn.Sequential(
nn.Linear(self.n_hidden*2, self.n_outputs),
nn.Dropout(p=0.5),
nn.Sigmoid()
)
else:
self.FC = nn.Sequential(
nn.Linear(self.n_hidden, self.n_outputs),
# nn.Dropout(p=0.5),
nn.Sigmoid()
)
def init_hidden(self, batch_size, device=None): # input 4D tensor: (batch size, channels, width, height)
# initialize the hidden and cell state to zero
# vectors:(number of layer, batch size, number of hidden nodes)
if (self.bidirectional):
h0 = torch.zeros(2*self.n_layers, batch_size, self.n_hidden)
c0 = torch.zeros(2*self.n_layers, batch_size, self.n_hidden)
else:
h0 = torch.zeros(self.n_layers, batch_size, self.n_hidden)
c0 = torch.zeros(self.n_layers, batch_size, self.n_hidden)
if device is not None:
h0 = h0.to(device)
c0 = c0.to(device)
self.hidden = (h0,c0)
def forward(self, X): # X: tensor of shape (batch_size, channels, width, height)
# forward propagate LSTM
lstm_out, self.hidden = self.lstm(X, self.hidden) # lstm_out: tensor of shape (batch_size, seq_length, hidden_size)
# Decode the hidden state of the last time step
out = self.FC(lstm_out[:, -1, :])
return out
答案 0 :(得分:0)
请编辑您的帖子并添加代码。您如何初始化隐藏状态?您的造型看起来如何。
hidden[0]
不是您的hidden-size,而是lstm的隐藏状态。隐藏状态的形状必须像这样初始化:
hidden = ( torch.zeros((batch_size, layers, hidden_size)), torch.zeros((layers, batch_size, hidden_size)) )
您似乎已正确执行此操作。但是错误告诉您,给了一批大小为1的批处理(因为正如您所说的,您只想使用一个样本进行测试),但是隐藏状态是使用batch-size = 128初始化的。
因此,我想(请添加代码)您已硬编码批次大小=128。不要这样做。由于每次正向传递都必须重新初始化隐藏状态,因此您可以执行以下操作:
...
def forward(self, x):
batch_size = x.shape[0]
hidden = (torch.zeros(self.layers, batch_size, self.hidden_size).to(device=device), torch.zeros(self.layers, batch_size, self.hidden_size).to(device=device))
output, hidden = lstm(x, hidden)
# then do what every you want with the output
我猜这是导致此错误的原因,但也请发布您的代码!