这个问题几乎说明了一切。我正在尝试在PyTorch中实现Pixel RNN的Row LSTM正向通过的门运算。我对第1行的第一个门的计算方式特别好奇,因为没有以前的单元状态或隐藏状态可以使用。
我知道这里有一个非常类似的问题,但是被接受的答案没有代码,这正是我要寻找的。
此外,我已经实现了自己的代码,但是训练起来很慢,这就是为什么我要查看某人是否具有更快的解决方案。
class RLSTM(nn.Module):
def __init__(self,ch):
super(RLSTM,self).__init__()
self.ch=ch
self.input_to_state = torch.nn.Conv2d(self.ch,4*self.ch,kernel_size=(1,3),padding=(0,1)).cuda()
self.state_to_state = torch.nn.Conv2d(self.ch,4*self.ch,kernel_size=(1,3),padding=(0,1)).cuda() # error is here: hidPrev is an array - not a valid number of input channel
self.cell_list = []
# check if these convolutions are changing their weights
def forward(self, image):
# print("starting forward")
# if(self.ch==64):
# print self.input_to_state.weight[0][0][0][0]
size = image.size()
b = size[0]
indvs = list(image.split(1,0))
tensor_array = []
for i in range(b):
tensor_array.append(self.RowLSTM(indvs[i]))
seq=tuple(tensor_array)
trans = torch.cat(seq,0)
global total
total+=1
# print("finished forward")
return trans.cuda()
def RowLSTM(self, image):
# input-to-state (K_is * x_i) : 1x3 convolution. generate h x n x n tensor. hxnxn tensor contains all i -> s info
# print("Starting LSTM")
self.cell_list=[]
igates = []
# print(image.size())
n = image.size()[2]
ch=image.size()[1]
for i in range(n):
if i==0:
# COULD BE THIS AREA, SHOULD EVERYTHING BE 0?
isgates = self.splitIS(self.input_to_state(image)) # convolve, then split into gates (4 per row)
#print("{} and {}".format(len(isgates),len(isgates[0])))
#cell=RowLSTMCell(0,torch.zeros(ch,n,1).cuda(),torch.zeros(ch,n,1).cuda(),torch.zeros(ch,n,1).cuda(),torch.zeros(ch,n,1).cuda(),torch.zeros(ch,n,1).cuda(),torch.zeros(ch,n,1).cuda())
cell=RowLSTMCell(0,isgates[0][0],isgates[0][1],isgates[0][2],isgates[0][3],torch.zeros(ch,n,1).cuda(),torch.zeros(ch,n,1).cuda())
cell.c=isgates[0][0]*isgates[0][3]
cell.h=torch.tanh(cell.c)*isgates[0][1]
# now have dummy variables for first row
self.cell_list.append(cell)
else:
cell_prev = self.cell_list[i-1]
hid_prev = cell_prev.getHiddenState()
ssgates = self.splitSS(self.state_to_state(hid_prev.unsqueeze(0)))
gates = self.addGates(isgates, ssgates,i)
ig, og, fg, gg = gates[0], gates[1], gates[2], gates[3]
cell = RowLSTMCell(cell_prev, ig, og, fg, gg, 0 ,0) #MORE zeros
cell.compute()
self.cell_list.append(cell)
# now have a list of all cell data, concatenate hidden state into 1 x h x n x n
hidden_layers = []
for i in range(n):
hid = self.cell_list[i].h
hidden_layers.append(torch.unsqueeze(hid,0))
seq = tuple(hidden_layers)
tensor = torch.cat(seq,3)
# print("finished lstm")
#print(tensor.size())
return tensor
def splitIS(self, tensor): #always going to be splitting into 4 pieces, so no need to add extra parameters
inputStateGates={}
size=tensor.size() # 1 x 4h x n x n
out_ft=size[1] # get 4h for the nxnx4h tensor
num=size[2] # get n for the nxn image
hh=out_ft/4 # we want to split the tensor into 4, for the gates
tensor = torch.squeeze(tensor).cuda() # 4h x n x n
# First, split by row: Creates n tensors of 4h x n x 1
rows = list(tensor.split(1,2))
for i in range(num):
# Each row is a tensor of 4h x n x 1, split it into 4 of h x n x 1
row=rows[i]
# print("Each row using cuda: "+str(row.is_cuda))
inputStateGates[i]=list(row.split(hh,0))
return inputStateGates
def splitSS(self, tensor): # 1 x 4h x n x 1, create 4 of 1 x h x n x 1
size=tensor.size()
out_ft=size[1] # get 4h for the 1x4hxn tensor
num=size[2] # get n for the 1xhxn row
hh=out_ft/4 # we want to split the tensor into 4, for the gates
tensor = tensor.squeeze(0).cuda() # 4h x n x 1
splitted=list(tensor.split(hh,0))
return splitted
def addGates(self, i2s,s2s,key):
""" these dictionaries are of form {key : [[i], [o], [f], [g]]}
we want to add pairwise elemeents """
# i2s is of form key: [[i], [o], [f], [g]] where each gate is hxn
# s2s is of form [[h,n],[h,n],[h,n], [h,n]]
gateSum = []
for i in range(4): # always of length 4, representing the gates
gateSum.append(torch.sigmoid(i2s[key][i] + s2s[i]))
return gateSum
这是每个单元格的类:
class RowLSTMCell(): #inherit torch.nn.LSTM?
def __init__(self,prev_row, i, o, f, g, c, h):
self.c=c
self.h=h
self.i=i
self.i = self.i.cuda()
self.o=o
self.o = self.o.cuda()
self.g=g
self.g = self.g.cuda()
self.f=f
self.f = self.f.cuda()
self.prev_row=prev_row
def getStateSize(self):
return self._state_size
def getOutputSize(self):
return self._output_size
def compute(self):
c_prev = self.prev_row.getCellState()
h_prev = self.prev_row.getHiddenState()
self.c = self.f * c_prev + self.i * self.g
self.h = torch.tanh(self.c) * self.o
def getHiddenState(self):
return self.h
def getCellState(self):
return self.c