取消选择保存的pytorch模型会引发AttributeError:无法在<module'__main __'=“”尽管=“” using =“” class =“” definition =“” inline =“”

时间:2019-04-03 06:50:40

标签: python pickle pytorch

=“”

我正在尝试在Flask应用中提供pytorch模型。当我较早在jupyter笔记本上运行此代码时,此代码有效,但现在我在虚拟环境中运行此代码,即使类定义在那里,它显然也无法获得属性“ Net”。所有其他类似的问题都告诉我要在同一脚本中添加已保存模型的类定义。但这仍然行不通。火炬版本为1.0.1(在其中保存的模型以及virtualenv都经过训练) 我究竟做错了什么? 这是我的代码。

import os
import numpy as np
from flask import Flask, request, jsonify 
import requests

import torch
from torch import nn
from torch.nn import functional as F


MODEL_URL = 'https://storage.googleapis.com/judy-pytorch-model/classifier.pt'


r = requests.get(MODEL_URL)
file = open("model.pth", "wb")
file.write(r.content)
file.close()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)

        return F.log_softmax(x, dim=-1)

model = torch.load('model.pth')

app = Flask(__name__)

@app.route("/")
def hello():
    return "Binary classification example\n"

@app.route('/predict', methods=['GET'])
def predict():


    x_data = request.args['x_data']

    x_data =  x_data.split()
    x_data = list(map(float, x_data))

    sample = np.array(x_data) 

    sample_tensor = torch.from_numpy(sample).float()

    out = model(sample_tensor)

    _, predicted = torch.max(out.data, -1)

    if predicted.item() == 0: 
         pred_class = "Has no liver damage - ", predicted.item()
    elif predicted.item() == 1:
        pred_class = "Has liver damage - ", predicted.item()

    return jsonify(pred_class)

这是完整的追溯:

Traceback (most recent call last):
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask", line 10, in <module>
    sys.exit(main())
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 894, in main
    cli.main(args=args, prog_name=name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 557, in main
    return super(FlaskGroup, self).main(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/decorators.py", line 64, in new_func
    return ctx.invoke(f, obj, *args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 767, in run_command
    app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 293, in __init__
    self._load_unlocked()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 317, in _load_unlocked
    self._app = rv = self.loader()
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 372, in load_app
    app = locate_app(self, import_name, name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/flask/cli.py", line 235, in locate_app
    __import__(module_name)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/app.py", line 34, in <module>
    model = torch.load('model.pth')
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 368, in load
    return _load(f, map_location, pickle_module)
  File "/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/lib/python3.6/site-packages/torch/serialization.py", line 542, in _load
    result = unpickler.load()
AttributeError: Can't get attribute 'Net' on <module '__main__' from '/Users/judyraj/Judy/pytorch-deployment/flask_app/liver_disease_finder/bin/flask'>

This无法解决我的问题。我不想更改我保留模型的方式。在虚拟环境之外,torch.save()对我来说工作正常。我不介意将类定义添加到脚本中。尽管如此,我试图查看是什么导致了错误。

3 个答案:

答案 0 :(得分:1)

(这是部分答案)

我不认为torch.save(model,'model.pt')在命令提示符下还是从一个以'__main__'运行的脚本中保存模型并从另一个脚本加载模型时是无效的。

原因是割炬必须自动加载用于保存文件的模块,并且它从__name__获取模块名称。

现在为部分内容:目前尚不清楚如何解决此问题,尤其是当您混合使用virtualenvs时。

感谢Jatentaki朝此方向开始对话。

答案 1 :(得分:1)

首先我初始化了一个空模型,然后加载了保存的模型,这出于某种原因解决了这个问题。

答案 2 :(得分:-1)

这可能不是一个很受欢迎的答案,但是,我发现 dill 包在使我的代码工作方面非常一致。对我来说,我什至没有尝试加载模型,我正在尝试解压缩一个可以帮助我的东西的自定义对象,但由于某种原因无法找到它。我不知道为什么,但根据我的经验,莳萝似乎是更好的酸洗选择:

    # - path to files
    path = Path(path2dataset).expanduser()
    path2file_data_prep = Path(path2file_data_prep).expanduser()
    # - create dag dataprep obj
    print(f'path to data set {path=}')
    dag_prep = SplitDagDataPreparation(path)
    # - save data prep splits object
    print(f'saving to {path2file_data_prep=}')
    torch.save({'data_prep': dag_prep}, path2file_data_prep, pickle_module=dill)
    # - load the data prep splits object to test it loads correctly
    db = torch.load(path2file_data_prep, pickle_module=dill)
    db['data_prep']
    print(db)
    return path2file_data_prep