这是pytorch lstmcell的示例:
rnn = nn.LSTMCell(10, 20)
input = torch.randn(6, 3, 10)
hx = torch.randn(3, 20)
cx = torch.randn(3, 20)
output = []
hx, cx = rnn(input[0], (hx, cx))
output.append(hx)
不确定如何将其转换为keras lstm / lstmcell
答案 0 :(得分:0)
原始Pytorch代码:
self.att_lstm = nn.LSTMCell(1536,512)
h_att,c_att = self.att_lstm(att_lstm_input,(state [0] [0],state [1] [0]))
状态[0] [0],状态[1] [0]是张量(10,512)
我在喀拉拉邦尝试过的东西:
inputs = Input(shape=(10, 1536))
lstm, h_att, c_att = LSTM(units=512, input_shape=(10,1536), name='core.att_lstm', return_state=True)(inputs)
所以我不确定是否正确。