我正在尝试按照文档中的说明将数据并行传输到GRU的网络中,并且不断收到相同的错误
"""Defines the neural network, losss function and metrics"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self, params, anchor_is_phrase):
"""
Simple LSTM, used to generate the LSTM for both the word and video
embeddings.
Args:
params: (Params) contains vocab_size, embedding_dim, lstm_hidden_dim
is_phrase: is word lstm or the vid lstm
"""
super(Net, self).__init__()
if anchor_is_phrase:
self.lstm = nn.DataParallel(nn.GRU(params.word_embedding_dim, params.hidden_dim, 1)).cuda()#, batch_first=True)
else:
self.lstm = nn.DataParallel(nn.GRU(params.vid_embedding_dim, params.hidden_dim, 1)).cuda() #, batch_first=True)
def forward(self, s, anchor_is_phrase = False):
"""
Forward prop.
"""
s, _ = self.lstm(s)
s.data.contiguous()
return s
该错误发生在先前代码中的s = _ = self.lstm(s)行:
here: s, _ = self.lstm(s)
s.data.contiguous()
return s
我收到以下错误消息:
s, _ = self.lstm(s)
File "/home/pavelameen/miniconda3/envs/TD2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/pavelameen/miniconda3/envs/TD2/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/pavelameen/miniconda3/envs/TD2/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/pavelameen/miniconda3/envs/TD2/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
raise output
File "/home/pavelameen/miniconda3/envs/TD2/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
output = module(*input, **kwargs)
File "/home/pavelameen/miniconda3/envs/TD2/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
result = self.forward(*input, **kwargs)
File "/home/pavelameen/miniconda3/envs/TD2/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 193, in forward
max_batch_size = input.size(0) if self.batch_first else input.size(1)
AttributeError: 'tuple' object has no attribute 'size'
有趣的部分是我尝试在第27行中输出s的类型,并且得到 PackedSequence ,为什么在lstm forward方法中将其转换为元组?