
时间:2019-11-27 13:16:54

标签: python python-3.x neural-network pytorch


def dataloader(messages, labels, sequence_length=30, batch_size=32, shuffle=False):
    Build a dataloader.
    if shuffle:
        indices = list(range(len(messages)))
        messages = [messages[idx] for idx in indices]
        labels = [labels[idx] for idx in indices]

    total_sequences = len(messages) #total number of twits

    for ii in range(0, total_sequences, batch_size):
        batch_messages = messages[ii: ii+batch_size]

        # First initialize a tensor of all zeros
        batch = torch.zeros((sequence_length, len(batch_messages)), dtype=torch.int64)
        for batch_num, tokens in enumerate(batch_messages):
            token_tensor = torch.tensor(tokens)
            # Left pad!
            start_idx = max(sequence_length - len(token_tensor), 0) #returns 0 is len(token_tensor) > seqeuence_length
            batch[start_idx:, batch_num] = token_tensor[:sequence_length] #replace each row in batch with the token

        label_tensor = torch.tensor(labels[ii: ii+len(batch_messages)])

        yield batch, label_tensor
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 3
batch_size = 100
num_epochs = 2
learning_rate = 0.003

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection

    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection 
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)        

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)        

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])

        return out

model_2 = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_2.parameters(), lr=learning_rate)
train_loader = dataloader(
            train_features, train_labels, batch_size=batch_size, sequence_length=20, shuffle=True)

# Train the model
total_step = 200

for epoch in range(num_epochs):

    for i, (text_batch, labels) in enumerate(train_loader):
        text_batch = text_batch.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model_2(text_batch)
        loss = criterion(outputs, labels)        

        # Backward and optimize

        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))


RuntimeError                              Traceback (most recent call last)
<ipython-input-74-8a935e97b39c> in <module>
     11         # Forward pass
---> 12         outputs = model_2(text_batch)
     13         loss = criterion(outputs, labels)

~\Anaconda3\envs\thesis\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

<ipython-input-64-21fa163d5c93> in forward(self, x)
     19         # Forward propagate LSTM
---> 20         out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
     22         # Decode the hidden state of the last time step

~\Anaconda3\envs\thesis\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~\Anaconda3\envs\thesis\lib\site-packages\torch\nn\modules\rnn.py in forward(self, input, hx)
    562             return self.forward_packed(input, hx)
    563         else:
--> 564             return self.forward_tensor(input, hx)
    566 class GRU(RNNBase):

~\Anaconda3\envs\thesis\lib\site-packages\torch\nn\modules\rnn.py in forward_tensor(self, input, hx)
    541         unsorted_indices = None
--> 543         output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
    545         return output, self.permute_hidden(hidden, unsorted_indices)

~\Anaconda3\envs\thesis\lib\site-packages\torch\nn\modules\rnn.py in forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices)
    521             hx = self.permute_hidden(hx, sorted_indices)
--> 523         self.check_forward_args(input, hx, batch_sizes)
    524         if batch_sizes is None:
    525             result = _VF.lstm(input, hx, self._get_flat_weights(), self.bias, self.num_layers,

~\Anaconda3\envs\thesis\lib\site-packages\torch\nn\modules\rnn.py in check_forward_args(self, input, hidden, batch_sizes)
    494     def check_forward_args(self, input, hidden, batch_sizes):
    495         # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
--> 496         self.check_input(input, batch_sizes)
    497         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

~\Anaconda3\envs\thesis\lib\site-packages\torch\nn\modules\rnn.py in check_input(self, input, batch_sizes)
    143             raise RuntimeError(
    144                 'input must have {} dimensions, got {}'.format(
--> 145                     expected_input_dim, input.dim()))
    146         if self.input_size != input.size(-1):
    147             raise RuntimeError(

RuntimeError: input must have 3 dimensions, got 2

0 个答案:
