How do I correctly convert a PyTorch LSTM to a Keras CuDNNLSTM?

Time: 2019-05-21 03:00:46

Tags: python tensorflow keras pytorch

I am trying to convert a Pytorch model to Tensorflow by hand for deployment. ONNX does not seem to handle the transition from a Pytorch LSTM to a Tensorflow CuDNNLSTM, which is why I am writing the conversion manually.

I tried the code below. It runs in an Anaconda environment with Python 2.7, Pytorch 1.0, tensorflow 1.12, and cuda9. The Pytorch layer follows a batch norm and runs without any bias, but since Keras does not offer a no-bias option, I simply assign a bias of 0.
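For orientation (this snippet is not part of the original post), these are the parameter shapes a bias-free bidirectional torch.nn.LSTM exposes with the sizes used below, and what they map to on the Keras side:

import torch

# Bias-free bidirectional LSTM with the same sizes as in the conversion script.
rnn = torch.nn.LSTM(input_size=80, hidden_size=512,
                    bidirectional=True, bias=False, batch_first=True)
for name, param in rnn.named_parameters():
    print(name, tuple(param.shape))
# weight_ih_l0          (2048, 80)   -> transposed, becomes the Keras kernel (80, 2048)
# weight_hh_l0          (2048, 512)  -> transposed, becomes the Keras recurrent_kernel (512, 2048)
# weight_ih_l0_reverse  (2048, 80)
# weight_hh_l0_reverse  (2048, 512)
# The per-gate ordering is i, f, g(=c), o on both sides, so a plain transpose lines up.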

import torch
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import CuDNNLSTM, Bidirectional
from tensorflow.keras.models import Sequential, Model

input_size = 80
hidden_size = 512
with torch.no_grad():
    rnn1 = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, bidirectional=True, bias=False, batch_first=True).cuda()

model = Sequential()
model.add(Bidirectional(CuDNNLSTM(hidden_size, return_sequences=True),  input_shape=(None, input_size), name='rnn'))

# CuDNNLSTM keeps separate input and recurrent bias vectors, so its bias has
# shape (8 * hidden_size,) -- twice the (4 * hidden_size,) bias of the plain Keras LSTM.
bias_size = rnn1.weight_hh_l0.detach().cpu().numpy().T.shape[1] * 2
# Bidirectional expects [kernel, recurrent_kernel, bias] for the forward
# direction followed by the same triple for the backward direction.
keras_format_weights = [
                    rnn1.weight_ih_l0.detach().cpu().numpy().T,
                    rnn1.weight_hh_l0.detach().cpu().numpy().T,
                    np.zeros(bias_size,),
                    rnn1.weight_ih_l0_reverse.detach().cpu().numpy().T,
                    rnn1.weight_hh_l0_reverse.detach().cpu().numpy().T,
                    np.zeros(bias_size,),
                  ]


model.layers[0].set_weights(keras_format_weights)

# One random sample of shape (batch, time, features); batch_first=True on the
# torch side matches Keras' (batch, timesteps, features) layout.
random_test = np.random.rand(1, 1, 80)

res1, _ = rnn1.forward(torch.FloatTensor(random_test).cuda())
res1 = res1.detach().cpu().numpy()
res2 = model.predict(random_test)

print(np.allclose(res1, res2, atol=1e-2))
print(res1)
print(res2)

Running this prints:

False
[[[ 0.01265562  0.07478553  0.0470101  ... -0.02260824  0.0243004
   -0.0261014 ]]]
[[[-0.05316251 -0.00230848  0.03070898 ...  0.01497027  0.00976444
   -0.01095549]]]

Now, this does work with the generic Keras LSTM, but I need the speed advantage of CuDNNLSTM, and Pytorch uses the same backend anyway.
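For comparison, a sketch of what that generic-LSTM variant can look like (not necessarily the exact code used; recurrent_activation='sigmoid' is assumed so the gate activations match CuDNN's, and the plain Keras LSTM takes a single bias of size 4 * hidden_size per direction):

from tensorflow.keras.layers import LSTM, Bidirectional
from tensorflow.keras.models import Sequential

lstm_model = Sequential()
lstm_model.add(Bidirectional(
    LSTM(hidden_size, recurrent_activation='sigmoid', return_sequences=True),
    input_shape=(None, input_size), name='rnn'))

lstm_bias = np.zeros(4 * hidden_size,)  # plain LSTM bias: 4 * hidden_size, not 8
lstm_model.layers[0].set_weights([
    rnn1.weight_ih_l0.detach().cpu().numpy().T,
    rnn1.weight_hh_l0.detach().cpu().numpy().T,
    lstm_bias,
    rnn1.weight_ih_l0_reverse.detach().cpu().numpy().T,
    rnn1.weight_hh_l0_reverse.detach().cpu().numpy().T,
    lstm_bias,
])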

1 answer:

Answer 0 (score: 0)

UPDATE: The solution was to convert the torch model to a Keras-based LSTM model first, and then call
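The call itself is not preserved above, but a minimal sketch of that route, assuming Keras' built-in LSTM/CuDNNLSTM weight conversion is applied when weights are loaded back from HDF5 (re-using lstm_model from the plain-LSTM sketch above; this is not the answer author's exact code):

from tensorflow.keras.layers import CuDNNLSTM, Bidirectional
from tensorflow.keras.models import Sequential

# Save the weights from the plain-LSTM model that already matches Pytorch.
lstm_model.save_weights('lstm_weights.h5')

# Build the CuDNN model with the same layer name; Keras rewrites the stored
# LSTM weights into the CuDNN layout (split biases, etc.) while loading them.
gpu_model = Sequential()
gpu_model.add(Bidirectional(
    CuDNNLSTM(hidden_size, return_sequences=True),
    input_shape=(None, input_size), name='rnn'))
gpu_model.load_weights('lstm_weights.h5', by_name=True)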
