我尝试在网络中使用这样的代码连接变量
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.view(x.size(0), -1)
x= torch.cat((x,angle),1) # from here I concat it.
x = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
x = self.dropout2(self.relu2(self.bn2(self.fc2(x))))
x = self.fc3(x)
然后我发现我的网络什么都不学,并且总是给出约50%的acc。所以我打印param.grad
正如我所料,他们都是南。有没有人以前遇到过这件事?
我之前没有连接就运行了代码并且运行良好。所以我想这就是摩擦,系统不会抛出任何错误或异常。如果需要任何其他备份信息,请告诉我。
谢谢。
答案 0 :(得分:1)
错误可能在您提供的代码之外的某处。尝试检查输入中是否存在nan,并检查损失函数是否不会导致nan。