我尝试在pytorch中使用view(),但是我无法输入3个参数。我不知道为什么它会不断出现此错误?谁能帮我这个?
def forward(self, input):
lstm_out, self.hidden = self.lstm(input.view(len(input), self.batch_size, -1))
答案 0 :(得分:0)
您的input
似乎是一个numpy数组,而不是割炬张量。您需要先进行转换,例如input = torch.Tensor(input)
。