Implementing the forward pass of the Pixel RNN Row LSTM in PyTorch

Date: 2018-07-31 16:16:23

Tags: python machine-learning lstm pytorch rnn

The question pretty much says it all. I am trying to implement the gate computations of the Row LSTM forward pass from Pixel RNN in PyTorch. I am particularly curious about how the gates of the first row are computed, since there is no previous cell state or hidden state to use.

I know there is a very similar question here, but the accepted answer has no code, which is exactly what I am looking for.

In addition, I have implemented my own version, but it trains very slowly, which is why I want to see whether anyone has a faster solution.
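Before the code, this is my current understanding of the recurrence as a minimal sketch (not my actual implementation): the input-to-state convolution can be applied to the whole image once, and the first row simply uses zero tensors in place of the missing previous hidden and cell states. `kis`, `kss` and `row_lstm_sketch` are placeholder names, and the gate ordering is just the one I assume from the paper.

import torch

def row_lstm_sketch(x, kis, kss):
    # x: 1 x h x n x n feature map; kis / kss: Conv2d(h, 4h, (1, 3), padding=(0, 1))
    h, n = x.size(1), x.size(2)
    iss = kis(x)  # 1 x 4h x n x n: input-to-state contribution for every row at once
    # Row 0 has no predecessor, so both recurrent states start out as zeros.
    h_prev = torch.zeros(1, h, 1, n, device=x.device)
    c_prev = torch.zeros(1, h, 1, n, device=x.device)
    rows = []
    for r in range(n):
        sss = kss(h_prev)                    # 1 x 4h x 1 x n: state-to-state from the row above
        gates = iss[:, :, r:r + 1, :] + sss  # combine both contributions
        o, f, i, g = gates.chunk(4, dim=1)
        o, f, i = torch.sigmoid(o), torch.sigmoid(f), torch.sigmoid(i)
        g = torch.tanh(g)
        c_prev = f * c_prev + i * g          # standard LSTM cell update
        h_prev = o * torch.tanh(c_prev)
        rows.append(h_prev)
    return torch.cat(rows, dim=2)            # 1 x h x n x n

# e.g.:
# kis = torch.nn.Conv2d(16, 64, kernel_size=(1, 3), padding=(0, 1))
# kss = torch.nn.Conv2d(16, 64, kernel_size=(1, 3), padding=(0, 1))
# out = row_lstm_sketch(torch.randn(1, 16, 8, 8), kis, kss)

My actual implementation is below: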

import torch
import torch.nn as nn

total = 0  # module-level counter incremented in forward(); defined here so `global total` works

class RLSTM(nn.Module):
    def __init__(self, ch):
        super(RLSTM, self).__init__()
        self.ch = ch
        # input-to-state (K_is): 1x3 convolution producing all four gates for every row at once
        self.input_to_state = torch.nn.Conv2d(self.ch, 4 * self.ch, kernel_size=(1, 3), padding=(0, 1)).cuda()
        # state-to-state (K_ss): 1x3 convolution over the previous row's hidden state
        self.state_to_state = torch.nn.Conv2d(self.ch, 4 * self.ch, kernel_size=(1, 3), padding=(0, 1)).cuda()  # error is here: hidPrev is an array - not a valid number of input channels
        self.cell_list = []

        # check if these convolutions are changing their weights

    def forward(self, image):
        # print("starting forward")
        # if self.ch == 64:
        #     print(self.input_to_state.weight[0][0][0][0])
        size = image.size()
        b = size[0]
        # process each example in the batch independently through the Row LSTM
        indvs = list(image.split(1, 0))
        tensor_array = []

        for i in range(b):
            tensor_array.append(self.RowLSTM(indvs[i]))
        seq = tuple(tensor_array)
        trans = torch.cat(seq, 0)
        global total
        total += 1
        # print("finished forward")
        return trans.cuda()
    def RowLSTM(self, image):
        # input-to-state (K_is * x_i): 1x3 convolution. Generates an h x n x n tensor
        # that contains the input-to-state contribution for every row.
        # print("Starting LSTM")
        self.cell_list = []
        igates = []
        # print(image.size())
        n = image.size()[2]
        ch = image.size()[1]
        for i in range(n):
            if i == 0:
                # COULD BE THIS AREA, SHOULD EVERYTHING BE 0?
                isgates = self.splitIS(self.input_to_state(image))  # convolve, then split into gates (4 per row)
                # print("{} and {}".format(len(isgates), len(isgates[0])))
                # cell = RowLSTMCell(0, torch.zeros(ch, n, 1).cuda(), torch.zeros(ch, n, 1).cuda(), torch.zeros(ch, n, 1).cuda(), torch.zeros(ch, n, 1).cuda(), torch.zeros(ch, n, 1).cuda(), torch.zeros(ch, n, 1).cuda())
                cell = RowLSTMCell(0, isgates[0][0], isgates[0][1], isgates[0][2], isgates[0][3], torch.zeros(ch, n, 1).cuda(), torch.zeros(ch, n, 1).cuda())
                cell.c = isgates[0][0] * isgates[0][3]
                cell.h = torch.tanh(cell.c) * isgates[0][1]
                # now have dummy variables for first row
                self.cell_list.append(cell)
            else:
                cell_prev = self.cell_list[i - 1]
                hid_prev = cell_prev.getHiddenState()
                ssgates = self.splitSS(self.state_to_state(hid_prev.unsqueeze(0)))
                gates = self.addGates(isgates, ssgates, i)
                ig, og, fg, gg = gates[0], gates[1], gates[2], gates[3]
                cell = RowLSTMCell(cell_prev, ig, og, fg, gg, 0, 0)  # MORE zeros
                cell.compute()

                self.cell_list.append(cell)

        # now have a list of all cell data, concatenate hidden state into 1 x h x n x n
        hidden_layers = []
        for i in range(n):
            hid = self.cell_list[i].h
            hidden_layers.append(torch.unsqueeze(hid,0))

        seq = tuple(hidden_layers)
        tensor = torch.cat(seq,3)
      #  print("finished lstm")
        #print(tensor.size())
        return tensor 

    def splitIS(self, tensor):  # always going to be splitting into 4 pieces, so no need to add extra parameters
        inputStateGates = {}
        size = tensor.size()   # 1 x 4h x n x n
        out_ft = size[1]       # get 4h for the n x n x 4h tensor
        num = size[2]          # get n for the n x n image
        hh = out_ft // 4       # integer division: split the 4h channels into the 4 gates
        tensor = torch.squeeze(tensor).cuda()  # 4h x n x n

        # First, split by row: creates n tensors of 4h x n x 1
        rows = list(tensor.split(1, 2))

        for i in range(num):
            # Each row is a tensor of 4h x n x 1; split it into 4 of h x n x 1
            row = rows[i]
            # print("Each row using cuda: " + str(row.is_cuda))
            inputStateGates[i] = list(row.split(hh, 0))

        return inputStateGates


    def splitSS(self, tensor):  # 1 x 4h x n x 1 -> create 4 of 1 x h x n x 1
        size = tensor.size()
        out_ft = size[1]       # get 4h for the 1 x 4h x n tensor
        num = size[2]          # get n for the 1 x h x n row
        hh = out_ft // 4       # integer division: split the 4h channels into the 4 gates
        tensor = tensor.squeeze(0).cuda()  # 4h x n x 1
        splitted = list(tensor.split(hh, 0))
        return splitted


    def addGates(self, i2s, s2s, key):
        """ These dictionaries are of the form {key: [[i], [o], [f], [g]]};
            we want to add pairwise elements. """

        # i2s is of form key: [[i], [o], [f], [g]] where each gate is h x n
        # s2s is of form [[h, n], [h, n], [h, n], [h, n]]
        gateSum = []
        for i in range(4):  # always of length 4, representing the gates
            gateSum.append(torch.sigmoid(i2s[key][i] + s2s[i]))
        return gateSum

Here is the class for each cell:

class RowLSTMCell():  # inherit torch.nn.LSTM?
    def __init__(self, prev_row, i, o, f, g, c, h):
        self.c = c
        self.h = h
        self.i = i.cuda()
        self.o = o.cuda()
        self.g = g.cuda()
        self.f = f.cuda()
        self.prev_row = prev_row
    def getStateSize(self):
        # note: self._state_size is never assigned anywhere, so this raises if called
        return self._state_size

    def getOutputSize(self):
        # note: self._output_size is never assigned anywhere, so this raises if called
        return self._output_size

    def compute(self):
        c_prev = self.prev_row.getCellState()
        h_prev = self.prev_row.getHiddenState()   
        self.c = self.f * c_prev + self.i * self.g
        self.h = torch.tanh(self.c) * self.o
    def getHiddenState(self):
        return self.h

    def getCellState(self):
        return self.c
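This is roughly how I call it (a minimal usage sketch; the channel count and spatial size are placeholder values, and a CUDA device is assumed since the layers above move themselves to the GPU):

model = RLSTM(ch=16)
x = torch.randn(2, 16, 8, 8).cuda()  # batch of 2 images, 16 channels, 8x8
out = model(x)                       # should come back with the same shape as the input
print(out.size())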

0 Answers:

There are no answers yet.