`BeatConvModel` 是我训练的模型的名称，以下是该模型中的部分层：
BeatConvModel(
(layers): Sequential(
(0): Dropout(p=0.5, inplace=False)
(1): Sequential(
(0): Conv1d(1, 128, kernel_size=(55,), stride=(1,), padding=(27,), bias=False)
(1): ReLU(inplace=True)
(2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_mode=False)
(3): Dropout(p=0.5, inplace=False)
(4): Sequential(
(0): Conv1d(128, 128, kernel_size=(25,), stride=(1,), padding=(12,), bias=False)
(1): ReLU(inplace=True)
(2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
我按如下方式保存了权重并导出了模型：
# Save the trained weights under the name 'ECG_model'.
learn.save('ECG_model')
# Export the whole Learner for inference; the deployment code below loads it
# back with load_learner(path, 'export.pkl').
# NOTE: the exported pickle refers to the model class BeatConvModel by name,
# so that class must be defined/importable in whatever script calls
# load_learner() -- otherwise unpickling fails with
# "AttributeError: Can't get attribute 'BeatConvModel' on <module '__main__'>".
learn.export()
但是当我尝试在应用程序中部署该模型时，出现了以下错误：
AttributeError Traceback (most recent call last)
<ipython-input-3-b84658d87aa7> in <module>
2 ECG = np.loadtxt('text-file.txt')
3 path = Path('pth-of-the-.pkl-file')
----> 4 BeatConvModel = load_learner(path, 'export.pkl')
5 pred_class,pred_idx,outputs = learn.predict(ECG)
~/anaconda3/envs/fastai/lib/python3.7/site-packages/fastai/basic_train.py in
load_learner(path, file, test, tfm_y, **db_kwargs)
614 "Load a `Learner` object saved with `export_state` in `path/file` with empty data,
optionally add `test` and load on `cpu`. `file` can be file-like (file or buffer)"
615 source = Path(path)/file if is_pathlike(file) else file
--> 616 state = torch.load(source, map_location='cpu') if defaults.device ==
torch.device('cpu') else torch.load(source)
617 model = state.pop('model')
618 src = LabelLists.load_state(path, state.pop('data'))
~/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
527 with _open_zipfile_reader(f) as opened_zipfile:
528 return _load(opened_zipfile, map_location, pickle_module,
**pickle_load_args)
--> 529 return _legacy_load(opened_file, map_location, pickle_module,
**pickle_load_args)
530
531
~/anaconda3/envs/fastai/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
700 unpickler = pickle_module.Unpickler(f, **pickle_load_args)
701 unpickler.persistent_load = persistent_load
--> 702 result = unpickler.load()
703
704 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
AttributeError: Can't get attribute 'BeatConvModel' on <module '__main__'>
我部署模型的代码如下：
# Deployment script: load the exported fastai Learner and classify one ECG
# signal read from a text file.
#
# NOTE: export.pkl refers to the model class BeatConvModel by name, so that
# class must be defined (or imported) in this script *before* load_learner()
# runs; otherwise unpickling raises
# "AttributeError: Can't get attribute 'BeatConvModel' on <module '__main__'>".
# TODO: import BeatConvModel from the module where it was defined at training
# time (e.g. move the class into its own .py file and import it here).
ECG = np.loadtxt('text-file.txt')
path = Path('pth-of-the-.pkl-file')
learner = load_learner(path, 'export.pkl')
# Predict with `learner`, the object bound just above; `learn` is not
# defined in this script and would raise NameError.
pred_class, pred_idx, outputs = learner.predict(ECG)
print(pred_class)
print(pred_idx)
print(outputs)