我想分离模型结构的创作和培训。模型作者设计模型结构,将未经训练的模型保存到文件中,然后将其发送给训练服务,以加载模型结构并训练模型。
Keras可以保存模型配置,然后load。
如何用PyTorch完成相同的工作?
答案 0 :(得分:1)
您可以编写自己的函数在PyTorch中进行操作。简单地执行torch.save(model.state_dict(), 'weightsAndBiases.pth')
即可节省权重。
要保存模型结构,可以执行以下操作:
(假设您有一个名为Network
的模型类,并且实例化了yourModel = Network()
)
model_structure = {'input_size': 784,
'output_size': 10,
'hidden_layers': [each.out_features for each in yourModel.hidden_layers],
'state_dict': yourModel.state_dict() #if you want to save the weights
}
torch.save(model_structure, 'model_structure.pth')
类似地,我们可以编写一个函数来加载结构。
def load_structure(filepath):
structure = torch.load(filepath)
model = Network(structure['input_size'],
structure['output_size'],
structure['hidden_layers'])
# model.load_state_dict(structure['state_dict']) if you had saved weights as well
return model
model = load_structure('model_structure.pth')
print(model)
编辑: 好的,上面是您可以访问您的类的源代码的情况,或者该类相对简单,因此您可以这样定义一个通用类:
class Network(nn.Module):
def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
''' Builds a feedforward network with arbitrary hidden layers.
Arguments
---------
input_size: integer, size of the input layer
output_size: integer, size of the output layer
hidden_layers: list of integers, the sizes of the hidden layers
'''
super().__init__()
# Input to a hidden layer
self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
# Add a variable number of more hidden layers
layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])
self.output = nn.Linear(hidden_layers[-1], output_size)
self.dropout = nn.Dropout(p=drop_p)
def forward(self, x):
''' Forward pass through the network, returns the output logits '''
for each in self.hidden_layers:
x = F.relu(each(x))
x = self.dropout(x)
x = self.output(x)
return F.log_softmax(x, dim=1)
但是,这仅适用于简单的情况,所以我认为这不是您想要的。
一种选择是,您可以在一个单独的.py文件中定义模型的体系结构,然后将其与其他必需项(如果模型体系结构很复杂)一起导入,或者可以在那时和那里完全定义模型。
另一种选择是将pytorch模型转换为onxx并保存。
另一种选择是,在Tensorflow中,您可以创建一个.pb
文件,该文件定义架构和模型的权重,而在Pytorch中,您可以通过以下方式进行操作:
torch.save(model, filepath)
这将保存模型对象本身,因为torch.save()只是一天结束时基于泡菜的保存。
model = torch.load(filepath)
但是这有局限性,例如,您的模型类定义可能不是可挑剔的(在某些复杂模型中可能)。 因为这是一个棘手的解决方法,所以通常会得到的答案是-不,您必须在加载经过训练的模型之前声明类定义,即您需要访问模型类源代码。
注意事项: 核心PyTorch开发人员之一针对没有代码加载pytorch模型的局限性给出了官方答案:
import foo
class MyModel(...):
def forward(input):
foo.bar(input)
此处,软件包foo未保存在模型检查点中。
鉴于这些限制,在没有原始源文件的情况下,没有可靠的方法来使torch.load工作。