PyTorch: How to correctly create a list of nn.Linear()

Date: 2018-05-22 09:15:18

Tags: python pytorch

I created a class that subclasses nn.Module.

In my class, I have to create N linear transformations, where N is given as a class parameter.

So I proceeded as follows:

    self.list_1 = []

    for i in range(N):
        self.list_1.append(nn.Linear(self.x, 1, bias=mlp_bias))

In the forward method, I call these matrices (with list_1[i]) and concatenate the results.
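
A minimal CPU-only sketch of the setup described above. The original class is not shown, so the names `MultiHead`, `in_features` and `n` are hypothetical. Because `list_1` is a plain Python list, the `nn.Linear` layers it holds are not registered as submodules, so `model.parameters()` does not see them (and, for example, an optimizer built from `model.parameters()` would not train them):

```python
import torch
import torch.nn as nn


class MultiHead(nn.Module):
    """Hypothetical reconstruction of the question's class (plain-list version)."""

    def __init__(self, in_features, n, mlp_bias=True):
        super().__init__()
        # A plain Python list: nn.Module does not track modules stored this way.
        self.list_1 = []
        for _ in range(n):
            self.list_1.append(nn.Linear(in_features, 1, bias=mlp_bias))

    def forward(self, x):
        # Apply each linear map and concatenate the N outputs along the last dim.
        return torch.cat([self.list_1[i](x) for i in range(len(self.list_1))], dim=-1)


model = MultiHead(in_features=4, n=3)
out = model(torch.randn(2, 4))
print(out.shape)                      # torch.Size([2, 3])
print(len(list(model.parameters())))  # 0: the layers in the list are invisible
```

The forward pass itself works on the CPU; the problem only shows up once something relies on the module's registry of children.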

Two things:

1)

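Even though I call model.cuda(), these linear transformations stay on the CPU, and I get the following error:

    RuntimeError: Expected object of type Variable[torch.cuda.FloatTensor] but found type Variable[torch.FloatTensor] for argument #1 'mat2'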

I have to do:

    self.list_1.append(nn.Linear(self.x, 1, bias=mlp_bias).cuda())

If instead I do the following, I don't need to:

    self.nn = nn.Linear(self.x, 1, bias=mlp_bias)

and then use self.nn directly.
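
The difference between the two cases can be reproduced without a GPU. `Module.cuda()` and `Module.double()` both convert the module's own parameters and then recurse into its registered children, so `.double()` is used below as a CPU stand-in for `.cuda()`. A minimal sketch with a hypothetical `Demo` class: the layer assigned directly as an attribute is converted, the one hidden in a list is not.

```python
import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigned as an attribute: registered as a submodule.
        self.nn = nn.Linear(4, 1)
        # Stored in a plain list: not registered.
        self.list_1 = [nn.Linear(4, 1)]


model = Demo()
# .double() walks registered submodules the same way .cuda() does,
# so it stands in for .cuda() on a machine without a GPU.
model.double()
print(model.nn.weight.dtype)         # torch.float64
print(model.list_1[0].weight.dtype)  # torch.float32 (left behind)
```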

2)

More obviously, when I print(model) in my main script, the linear layers in the list are not printed.
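
`print(model)` only lists registered child modules, so this symptom has the same cause as problem 1. A minimal sketch with a hypothetical `ListOnly` class whose only layers live in a plain list:

```python
import torch.nn as nn


class ListOnly(nn.Module):
    def __init__(self, n=3):
        super().__init__()
        # Layers kept in a plain list are not child modules.
        self.list_1 = [nn.Linear(4, 1) for _ in range(n)]


model = ListOnly()
print(model)                         # ListOnly()  (no layers shown)
print(list(model.named_children()))  # []
```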

Is there another way? Maybe with bmm? I didn't find that very easy, and I actually want to get my N results separately.
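
On the `bmm` question, one alternative not mentioned in the thread is to replace the N separate `nn.Linear(in_features, 1)` layers with a single `nn.Linear(in_features, N)`: stacking N weight rows of shape (1, in_features) gives one weight matrix of shape (N, in_features). The N results can then still be taken apart, for example with `Tensor.unbind`. A minimal sketch with hypothetical sizes that copies the separate weights into the fused layer to check the equivalence:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_features, n = 4, 3

# N separate single-output layers, as in the question.
heads = [nn.Linear(in_features, 1) for _ in range(n)]

# One layer whose weight rows are the N heads stacked together.
fused = nn.Linear(in_features, n)
with torch.no_grad():
    fused.weight.copy_(torch.cat([h.weight for h in heads], dim=0))  # (n, in_features)
    fused.bias.copy_(torch.cat([h.bias for h in heads], dim=0))      # (n,)

x = torch.randn(2, in_features)
separate = torch.cat([h(x) for h in heads], dim=-1)  # (2, n)
together = fused(x)                                   # (2, n)
print(torch.allclose(separate, together, atol=1e-6))  # True

# Recover the N per-head results individually, each of shape (2,).
results = together.unbind(dim=-1)
print(len(results), results[0].shape)  # 3 torch.Size([2])
```

In a real model one would simply train `fused` directly; the copy step is only there to show that both forms compute the same thing.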

Thanks in advance,

M

1 answer:

Answer 0 (score: 4)

You can use nn.ModuleList to wrap your list of linear layers, as explained here.
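
A minimal sketch of that fix, reusing the question's loop with hypothetical names (`MultiHead`, `in_features`, `n`). `nn.ModuleList` supports `append`, indexing and `len` like a Python list, but registers each layer as a submodule, so `.cuda()`/`.to()`, `.parameters()` and `print(model)` all see the layers:

```python
import torch
import torch.nn as nn


class MultiHead(nn.Module):
    """Hypothetical version of the question's class using nn.ModuleList."""

    def __init__(self, in_features, n, mlp_bias=True):
        super().__init__()
        # Same loop as before, but the container registers every layer.
        self.list_1 = nn.ModuleList()
        for _ in range(n):
            self.list_1.append(nn.Linear(in_features, 1, bias=mlp_bias))

    def forward(self, x):
        # Indexing works just like a Python list.
        outs = [self.list_1[i](x) for i in range(len(self.list_1))]
        return torch.cat(outs, dim=-1)


model = MultiHead(in_features=4, n=3)
print(len(list(model.parameters())))  # 6: a weight and a bias per layer
print(model)                          # now shows the ModuleList and its Linear layers

# No extra .cuda() per layer is needed; moving the model moves all of them.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
out = model(torch.randn(2, 4, device=device))
print(out.shape)  # torch.Size([2, 3])
```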