I have a PyTorch model that is 386 MB on disk, but when I load it with
state = torch.load(f, flair.device)
my GPU memory usage jumps to about 900 MB. Why does this happen, and is there a workaround?
This is how I save the model:
model_state = self._get_state_dict()
# additional fields for model checkpointing
model_state["optimizer_state_dict"] = optimizer_state
model_state["scheduler_state_dict"] = scheduler_state
model_state["epoch"] = epoch
model_state["loss"] = loss
torch.save(model_state, str(model_file), pickle_protocol=4)
Answer 0 (score: 2)
It is probably the optimizer_state_dict entry that takes up the extra space. Some optimizers (e.g. Adam) keep per-parameter statistics for every trainable parameter, such as first- and second-moment estimates, and these buffers occupy roughly as much memory as the model weights themselves.
You can load the checkpoint onto the CPU first:
state = torch.load(f, map_location=torch.device('cpu'))
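A minimal sketch of that approach, using a stand-in nn.Linear model and a hypothetical checkpoint.pt path (the actual model class and file name depend on your setup): the whole checkpoint is unpickled on the CPU, the extra checkpointing fields are dropped, and only the model weights are moved to the GPU.

```python
import torch
import torch.nn as nn

# Stand-in for the real model; the question's model is much larger.
model = nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters())

# One training step so Adam allocates its per-parameter moment buffers.
loss = model(torch.randn(2, 10)).sum()
loss.backward()
optimizer.step()

# Save a combined checkpoint like in the question: weights plus extras.
checkpoint = dict(model.state_dict())
checkpoint["optimizer_state_dict"] = optimizer.state_dict()
checkpoint["epoch"] = 1
torch.save(checkpoint, "checkpoint.pt", pickle_protocol=4)

# Load on the CPU so nothing lands in GPU memory during unpickling.
state = torch.load("checkpoint.pt", map_location=torch.device("cpu"))

# Strip the extra checkpointing fields before restoring the weights.
for key in ("optimizer_state_dict", "epoch"):
    state.pop(key, None)

model.load_state_dict(state)

# Move only the model weights to the GPU, if one is available.
if torch.cuda.is_available():
    model = model.cuda()
```

This way the optimizer's moment buffers never reach the GPU; they stay in the CPU-side dictionary and can be garbage-collected (or used to restore the optimizer on the CPU) as needed.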