将深度迁移学习单GPU训练代码修改为多GPU训练时,代码测试部分出现如下错误:
Traceback (most recent call last):
File "../main.py", line 155, in <module>
main(args)
File "../main.py", line 57, in main
t_correct = test(args, model, tar_test_loader, cuda_stat)
File "../main.py", line 136, in test
s_output, _, _ = model(data, data, target)
File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
return self.gather(outputs, self.output_device)
File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 180, in gather
return gather(outputs, output_device, dim=self.dim)
File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 76, in gather
res = gather_map(outputs)
File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
return type(out)(map(gather_map, zip(*outputs)))
File "/home/maxin2/project/anaconda3/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py", line 71, in gather_map
return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'int' object is not iterable
我找不到错误的原因。以下是我的测试代码的一部分:
def test(args, model, target_test_loader, cuda_stat):
model.eval()
test_loss = 0
correct = 0
criterion = torch.nn.CrossEntropyLoss(reduction='sum')
len_target_dataset = len(target_test_loader.dataset)
with torch.no_grad():
for data, target in target_test_loader:
if cuda_stat:
data, target = data.cuda(), target.cuda()
s_output, _, _ = model(data, data, target)
test_loss += criterion(s_output, target)# sum up batch loss
pred = torch.max(s_output, 1)[1] # get the index of the max log-probability
print(pred)
correct += torch.sum(pred == target)
test_loss /= len_target_dataset
print(args.test_dir, ' Test set:Loss: {} {:.6f} Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len_target_dataset,
100. * correct / len_target_dataset))
return correct