I'm trying to manually convert a PyTorch model to TensorFlow for deployment. ONNX doesn't seem to handle going from a PyTorch LSTM to a TensorFlow CuDNNLSTM, which is why I'm writing the conversion by hand.
I've tried the code below. It runs in an Anaconda environment with Python 2.7, PyTorch 1.0, TensorFlow 1.12 and CUDA 9. I'm running the PyTorch layers without any bias, since they follow a batch norm, but Keras doesn't offer that option, so I simply assign a zero bias.
import torch
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import CuDNNLSTM, Bidirectional
from tensorflow.keras.models import Sequential, Model

input_size = 80
hidden_size = 512

with torch.no_grad():
    # Bias-free bidirectional PyTorch LSTM to copy the weights from
    rnn1 = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                         bidirectional=True, bias=False, batch_first=True).cuda()

    # Equivalent Keras model built around CuDNNLSTM
    model = Sequential()
    model.add(Bidirectional(CuDNNLSTM(hidden_size, return_sequences=True),
                            input_shape=(None, input_size), name='rnn'))

    # CuDNNLSTM keeps separate input and recurrent biases (8 * hidden_size
    # per direction); the torch layer has none, so fill with zeros
    bias_size = rnn1.weight_hh_l0.detach().cpu().numpy().T.shape[1] * 2
    keras_format_weights = [
        rnn1.weight_ih_l0.detach().cpu().numpy().T,
        rnn1.weight_hh_l0.detach().cpu().numpy().T,
        np.zeros(bias_size,),
        rnn1.weight_ih_l0_reverse.detach().cpu().numpy().T,
        rnn1.weight_hh_l0_reverse.detach().cpu().numpy().T,
        np.zeros(bias_size,),
    ]
    model.layers[0].set_weights(keras_format_weights)

    # Compare the two implementations on the same random input
    random_test = np.random.rand(1, 1, 80)
    res1, _ = rnn1.forward(torch.FloatTensor(random_test).cuda())
    res1 = res1.detach().cpu().numpy()
    res2 = model.predict(random_test)

    print(np.allclose(res1, res2, atol=1e-2))
    print(res1)
    print(res2)
The output:

False
[[[ 0.01265562  0.07478553  0.0470101  ... -0.02260824  0.0243004
   -0.0261014 ]]]
[[[-0.05316251 -0.00230848  0.03070898 ...  0.01497027  0.00976444
   -0.01095549]]]

Now, the same weight transfer does work with a generic Keras LSTM, but I need the speed advantage of CuDNNLSTM, and PyTorch is using the same cuDNN backend anyway.
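For reference, this is roughly what the generic-LSTM check mentioned above looks like. It is a sketch rather than code from the original post: it reuses rnn1, random_test, res1, input_size and hidden_size from the snippet above, sets recurrent_activation='sigmoid' because PyTorch uses a true sigmoid where this Keras version defaults to hard_sigmoid, and passes use_bias=False so no dummy zero biases are needed.

from tensorflow.keras.layers import LSTM

ref_model = Sequential()
ref_model.add(Bidirectional(
    LSTM(hidden_size,
         return_sequences=True,
         recurrent_activation='sigmoid',  # match PyTorch's plain sigmoid
         use_bias=False),                 # mirror bias=False in the torch layer
    input_shape=(None, input_size)))

# With use_bias=False each direction only expects a kernel and a recurrent
# kernel, so the zero bias vectors are dropped.
ref_model.layers[0].set_weights([
    rnn1.weight_ih_l0.detach().cpu().numpy().T,
    rnn1.weight_hh_l0.detach().cpu().numpy().T,
    rnn1.weight_ih_l0_reverse.detach().cpu().numpy().T,
    rnn1.weight_hh_l0_reverse.detach().cpu().numpy().T,
])

res3 = ref_model.predict(random_test)
print(np.allclose(res1, res3, atol=1e-2))  # expected to print True if the plain LSTM matches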
Answer 0 (score: 0)

Update: the solution was to convert the torch model to a Keras-based LSTM model first and then call
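The call at the end of that sentence is cut off on this page. One route that fits the description is to let Keras convert plain-LSTM weights into the CuDNNLSTM layout while loading them from a saved file. The sketch below is hypothetical and version-dependent (the conversion happens inside Keras' weight-loading code, and only if this Keras version supports LSTM/CuDNNLSTM conversion for layers wrapped in Bidirectional); it reuses rnn1, input_size and hidden_size from the question and uses zero biases of size 4 * hidden_size per direction, which Keras maps onto the cuDNN bias layout when loading.

from tensorflow.keras.layers import LSTM

# 1. Intermediate Keras model with a plain LSTM and the transposed torch weights
#    (zero biases, since the torch layer has none).
lstm_model = Sequential()
lstm_model.add(Bidirectional(
    LSTM(hidden_size, return_sequences=True, recurrent_activation='sigmoid'),
    input_shape=(None, input_size), name='rnn'))
lstm_model.layers[0].set_weights([
    rnn1.weight_ih_l0.detach().cpu().numpy().T,
    rnn1.weight_hh_l0.detach().cpu().numpy().T,
    np.zeros(4 * hidden_size),
    rnn1.weight_ih_l0_reverse.detach().cpu().numpy().T,
    rnn1.weight_hh_l0_reverse.detach().cpu().numpy().T,
    np.zeros(4 * hidden_size),
])
lstm_model.save_weights('torch_as_keras_lstm.h5')

# 2. Target model with CuDNNLSTM; load_weights converts the saved LSTM weights
#    to the CuDNNLSTM format on the fly, if supported by this Keras version.
cudnn_model = Sequential()
cudnn_model.add(Bidirectional(
    CuDNNLSTM(hidden_size, return_sequences=True),
    input_shape=(None, input_size), name='rnn'))
cudnn_model.load_weights('torch_as_keras_lstm.h5')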