Can someone explain this to me?

Date: 2021-08-01 12:45:00

Tags: function variables deep-learning pytorch

Hi everyone, I'm working with a model in PyTorch. My code is as follows:

def test_data(mdl):
    #Input new data
    age=float(input("What is the person's age? (18-90) "))
    sex=input("What is the person's sex? (Male/Female) ").capitalize()
    edx=input("What is the person's education level? (3-16)")
    ms=input("what is the person's martial status?")
    wcs=input("what is the person's workclass?")
    ocs=input("What is the person's occupation?")
    wrk_hrs=input("How many hours/week are worked?")
    
    #Preprocess the data
    sex_d={"Male":1,"Female":0}
    mar_d={"Married":1,"Single":0,"Civil-Partnership":2,"union":3,"Divorced":4,"Widowed":5}
    wrk_d = {'Federal-gov':0, 'Local-gov':1, 'Private':2, 'Self-emp':3, 'State-gov':4}
    occ_d = {'Adm-clerical':0, 'Craft-repair':1, 'Exec-managerial':2, 'Farming-fishing':3, 'Handlers-cleaners':4,
            'Machine-op-inspct':5, 'Other-service':6, 'Prof-specialty':7, 'Protective-serv':8, 'Sales':9, 
            'Tech-support':10, 'Transport-moving':11}
    sex=sex_d[sex]
    ms=mar_d[ms]
    wcs=wrk_d[wcs]
    ocs=occ_d[ocs]

cats=torch.tensor([sex,ms,wcs,ocs],dtype=torch.int64).reshape(1,-1)
conts=torch.tensor([wrk_hrs,age],dtype=torch.float32).reshape(1,-1)

model.eval()

with torch.no_grad():
    z=model(cats,conts).argmax().item()
print(f'\nThe predicted label is {z}')

test_data(model)

But I'm confused about what this part does:

    sex=sex_d[sex]
    ms=mar_d[ms]
    wcs=wrk_d[wcs]
    ocs=occ_d[ocs]

I need to know what the part above does and how it works, because I don't understand what this code is doing. Can anyone explain?
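The four lines in question are ordinary dictionary lookups. `sex_d[sex]` uses the string the user typed (for example `"Female"`) as a key, fetches the integer stored under that key, and rebinds the same variable name to that integer. Presumably the model's categorical input (`cats`, built with `dtype=torch.int64`) expects integer codes rather than strings, so these codes would need to match the encoding used when the model was trained. A minimal plain-Python sketch of one such lookup, reusing the question's `sex_d`:

```python
# Same mapping as `sex_d` in the question: category name -> integer code
sex_d = {"Male": 1, "Female": 0}

sex = "Female"       # e.g. what input(...).capitalize() returned
sex = sex_d[sex]     # look up the key "Female" and rebind `sex` to the stored value
print(sex)           # 0

# A key that is not in the dictionary raises KeyError instead of returning a code
try:
    sex_d["female"]  # lookups are case-sensitive, so this key does not exist
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'female'
```

`mar_d`, `wrk_d` and `occ_d` work the same way; for example `mar_d["Married"]` gives `1`. The typed answer must match a key exactly, including capitalization, which is why the question calls `.capitalize()` on the sex answer.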

1 answer:

Answer 0 (score: 0)

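Python uses indentation to decide which lines belong to a function body, and variables assigned inside a function are local to it. Your indentation is broken: the lines from `cats=...` onward are not indented, so they run outside the `test_data` function and try to use variables (`sex`, `ms`, `wcs`, `ocs`, `wrk_hrs`, `age`) that only exist inside it. Learn more about Python scope here: https://www.w3schools.com/PYTHON/python_scope.asp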
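The scoping problem can be reproduced without PyTorch. In the question's code, the unindented line `cats=torch.tensor([sex,ms,wcs,ocs],...)` runs at module level, where the names assigned inside `test_data` do not exist, so (assuming no global variables with those names) Python raises `NameError`. The function name `build_features` below is a hypothetical stand-in:

```python
def build_features():
    # Indented: this line is inside the function, so `cats` is local to it
    cats = [1, 0, 2, 7]
    return cats

# Not indented: this runs at module level, where the local `cats` is not visible
try:
    print(cats)
except NameError as err:
    print("NameError:", err)

# Returning the value (or moving the code into the function body) fixes it
cats = build_features()
print(cats)  # [1, 0, 2, 7]
```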

Change your code to the following:

import torch

def test_data(mdl):
    #Input new data (input() always returns a string)
    age=float(input("What is the person's age? (18-90) "))
    sex=input("What is the person's sex? (Male/Female) ").capitalize()
    edx=input("What is the person's education level? (3-16)")  # read but not used below
    ms=input("What is the person's marital status?")
    wcs=input("What is the person's workclass?")
    ocs=input("What is the person's occupation?")
    wrk_hrs=float(input("How many hours/week are worked?"))  # convert to a number for the float tensor
    
    #Preprocess the data: map each category name to its integer code
    sex_d={"Male":1,"Female":0}
    mar_d={"Married":1,"Single":0,"Civil-Partnership":2,"union":3,"Divorced":4,"Widowed":5}
    wrk_d = {'Federal-gov':0, 'Local-gov':1, 'Private':2, 'Self-emp':3, 'State-gov':4}
    occ_d = {'Adm-clerical':0, 'Craft-repair':1, 'Exec-managerial':2, 'Farming-fishing':3, 'Handlers-cleaners':4,
            'Machine-op-inspct':5, 'Other-service':6, 'Prof-specialty':7, 'Protective-serv':8, 'Sales':9, 
            'Tech-support':10, 'Transport-moving':11}
    sex=sex_d[sex]
    ms=mar_d[ms]
    wcs=wrk_d[wcs]
    ocs=occ_d[ocs]

    # Build the model inputs inside the function, where the variables above are in scope
    cats=torch.tensor([sex,ms,wcs,ocs],dtype=torch.int64).reshape(1,-1)
    conts=torch.tensor([wrk_hrs,age],dtype=torch.float32).reshape(1,-1)

    mdl.eval()  # use the model passed in as the argument

    with torch.no_grad():
        z=mdl(cats,conts).argmax().item()
    print(f'\nThe predicted label is {z}')

# Call the function at module level (not indented), passing your trained model
test_data(model)
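Every value read with `input()` is a string, so numeric answers such as the age and the weekly hours must be converted with `float()` before they go into `torch.tensor(..., dtype=torch.float32)`. If you also want to re-ask on a typo instead of crashing, a small retry loop is one option. This is a sketch under assumptions: `ask_float` and its `reader` parameter are hypothetical, not part of the original code.

```python
def ask_float(prompt, low, high, reader=input):
    """Hypothetical helper: ask until the answer parses as a float in [low, high]."""
    while True:
        text = reader(prompt)
        try:
            value = float(text)  # input() returns a str, so convert it explicitly
        except ValueError:
            print(f"{text!r} is not a number, try again.")
            continue
        if low <= value <= high:
            return value
        print(f"Please enter a value between {low} and {high}.")
```

Inside `test_data` it could replace the plain calls, e.g. `age = ask_float("What is the person's age? (18-90) ", 18, 90)`; for the hours, an upper bound of 168 (the hours in one week) is an assumed limit. Passing `reader` lets you feed canned answers when testing without a keyboard.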