How do I implement a convolutional GRU / LSTM / RNN in PyTorch?

Time: 2019-08-15 15:07:35

Tags: python machine-learning deep-learning pytorch

I want to set up a standard LSTM / GRU / RNN, but swap the linear functions for convolutions. Can this be done in PyTorch in a clean and efficient way?

Ideally it would still work with packing, variable sequence lengths, and so on.
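
(For context, by "packing" I mean torch.nn.utils.rnn.pack_padded_sequence / pad_packed_sequence. Below is a minimal sketch of how they are used with the standard nn.LSTM; the batch of two padded sequences and their lengths are purely illustrative.)

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=3, hidden_size=6)
## two sequences of lengths 5 and 3, padded to a common length of 5; shape (seq_len, batch, input_size)
padded = torch.randn(5, 2, 3)
lengths = torch.tensor([5, 3])  # must be in decreasing order (enforce_sorted=True by default)
packed = pack_padded_sequence(padded, lengths)
packed_out, (h_n, c_n) = lstm(packed)
out, out_lengths = pad_packed_sequence(packed_out)  # out: (5, 2, 6)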

A small code example showing a simple way to pass the data through would be very useful, for example:

# Based on Robert Guthrie tutorial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from pdb import set_trace as st

torch.manual_seed(1)

def step_by_step(lstm, sequence, hidden):
    '''
    An example of an LSTM processing the sequence one token at a time (one time step at a time).
    '''
    ## process sequence one element at a time
    print()
    print('start processing sequence')
    for i, token in enumerate(sequence):
        print(f'-- i = {i}')
        #print(f'token.size() = {token.size()}')
        ## to add fake batch_size and fake seq_len
        h_n, c_n = hidden # hidden states, cell state
        processed_token = token.view(1, 1, -1) # torch.Size([1, 1, 3])
        print(f'processed_token.size() = {processed_token.size()}')
        print(f'h_n.size() = {h_n.size()}')
        #print(f'processed_token = {processed_token}')
        #print(f'h_n = {h_n}')
        # after each step, hidden contains the hidden state.
        out, hidden = lstm(processed_token, hidden)
    ## print results
    print()
    print(out)
    print(hidden)

def whole_seq_all_at_once(lstm, sequence, hidden):
    '''
    Alternatively, we can process the entire sequence all at once.
    The first value returned by the LSTM is all of the hidden states throughout
    the sequence. The second is just the most recent hidden state
    (compare the last slice of "out" with "hidden" below, they are the same).
    The reason for this is that:
    "out" gives you access to all hidden states in the sequence, while
    "hidden" allows you to continue the sequence and backpropagate,
    by passing it as an argument to the lstm at a later time.
    '''
    h, c = hidden
    Tx = len(sequence)
    ## concatenates list of tensors in the dim 0, i.e. stacks them downwards creating new rows
    sequence = torch.cat(sequence) # (5, 3)
    ## add a singleton dimension of size 1
    sequence = sequence.view(len(sequence), 1, -1) # (5, 1, 3)
    print(f'sequence.size() = {sequence.size()}')
    print(f'h.size() = {h.size()}')
    print(f'c.size() = {c.size()}')
    out, hidden = lstm(sequence, hidden)
    ## "out" will give you access to all hidden states in the sequence
    print()
    print(f'out = {out}')
    print(f'out.size() = {out.size()}') # (5, 1, 6)
    ##
    h_n, c_n = hidden
    print(f'h_n = {h_n}')
    print(f'h_n.size() = {h_n.size()}')
    print(f'c_n = {c_n}')
    print(f'c_n.size() = {c_n.size()}')

if __name__ == '__main__':
    ## model params
    hidden_size = 6
    input_size = 3
    lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size)
    ## make a sequence of length Tx (list of Tx tensors)
    Tx = 5
    sequence = [torch.randn(1, input_size) for _ in range(Tx)]  # make a sequence of length 5
    ## initialize the hidden state.
    hidden = (torch.randn(1, 1, hidden_size), torch.randn(1, 1, hidden_size))
    # step_by_step(lstm, sequence, hidden)
    whole_seq_all_at_once(lstm, sequence, hidden)
    print('DONE \a')
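
One possible approach (a minimal sketch, not a built-in PyTorch module; the ConvLSTMCell class, channel counts, kernel size, and spatial sizes below are illustrative assumptions) is to compute the four LSTM gates with a single Conv2d over the channel-concatenated input and hidden state, then step over the sequence manually, analogous to step_by_step above. Note that such a hand-rolled cell does not get the packing support that nn.LSTM provides:

import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    '''
    One ConvLSTM step: the four gate pre-activations are computed with a single
    Conv2d over the channel-concatenated input and hidden state, instead of the
    linear projections used by nn.LSTMCell.
    '''
    def __init__(self, in_channels, hidden_channels, kernel_size=3):
        super().__init__()
        self.hidden_channels = hidden_channels
        padding = kernel_size // 2  # keep the spatial size unchanged
        self.gates = nn.Conv2d(in_channels + hidden_channels,
                               4 * hidden_channels, kernel_size, padding=padding)

    def forward(self, x, state):
        ## x: (batch, in_channels, H, W); state: (h, c), each (batch, hidden_channels, H, W)
        h, c = state
        gates = self.gates(torch.cat([x, h], dim=1))
        i, f, g, o = torch.chunk(gates, 4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g
        h = o * torch.tanh(c)
        return h, (h, c)

if __name__ == '__main__':
    ## step over a sequence of feature-map "frames", one time step at a time
    Tx, batch, in_channels, hidden_channels, H, W = 5, 1, 3, 6, 8, 8
    cell = ConvLSTMCell(in_channels, hidden_channels)
    frames = [torch.randn(batch, in_channels, H, W) for _ in range(Tx)]
    state = (torch.zeros(batch, hidden_channels, H, W),
             torch.zeros(batch, hidden_channels, H, W))
    for frame in frames:
        out, state = cell(frame, state)
    print(out.size())  # torch.Size([1, 6, 8, 8])

Fusing all four gates into one Conv2d mirrors the way nn.LSTM fuses its linear projections into a single matrix multiply; variable sequence lengths would have to be handled by masking, since the cell is stepped manually rather than fed a packed sequence.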

Cross-posted:

0 answers:

No answers