如何使用千层面继续训练预训练模型

时间:2017-10-25 16:11:16

标签: python-3.x lasagne pre-trained-model

我训练了一个拥有1000次迭代的网络,并希望在没有从头开始的情况下继续进行2000次迭代的训练。我读了这个问题的不同方法并编写了下面的代码,所以最后我的参数在'saved_pa​​rams'中。但是从现在开始我没有得到这些参数。

有人能解释我如何继续吗?如何在培训过程中获得这些参数?

from __future__ import print_function
import numpy as np
import theano
import lasagne
import pickle


input_var=None
ini = lasagne.init.HeUniform()

l_in = lasagne.layers.InputLayer(shape=(None, 1, 120, 120), input_var=input_var)
b= np.zeros((1, 4), dtype=theano.config.floatX)
b = b.flatten()

loc_l1 = lasagne.layers.MaxPool2DLayer(l_in, pool_size=(2, 2))
loc_l2 = lasagne.layers.Conv2DLayer(loc_l1, num_filters=20, filter_size=(5, 5), W=ini)
loc_l3 = lasagne.layers.MaxPool2DLayer(loc_l2, pool_size=(2, 2))
loc_l4 = lasagne.layers.Conv2DLayer(loc_l3, num_filters=20, filter_size=(5, 5), W=ini)
loc_l5 = lasagne.layers.DenseLayer(loc_l4, num_units=50, W=lasagne.init.HeUniform('relu'))
network = lasagne.layers.DenseLayer(loc_l5, num_units=4, b=b, W=lasagne.init.Constant(0.0), nonlinearity=lasagne.nonlinearities.identity)


def save_network(filename,param_values):
    f = open(filename, 'wb')
    pickle.dump(param_values,f,protocol=-1)
    f.close()

def load_network(filename):
    f = open(filename, 'rb')
    param_values = pickle.load(f)
    f.close()
    return param_values


save_network("model.npz",lasagne.layers.get_all_param_values(network))

saved_params = load_network("model.npz")
lasagne.layers.set_all_param_values(network, saved_params)

3 个答案:

答案 0 :(得分:0)

您可以使用加载和后调用方法或更改参数吗? 如果你想要一个图表,那么保存1000个纪元的错误

答案 1 :(得分:0)

if(load):
        net1 = Lenet(classes, num_epochs)
        net1.load_weights_from('Lenet.npz')
        network = net1
        train_X = np.float32(train_X)
        print("train_x",train_X)
        print("train_y",train_Y)
        train_Y = np.int16(train_Y)
        network = net1.fit(train_X, train_Y, num_epochs)
        print ("Loading weights successfully done.")

答案 2 :(得分:0)

这段代码只是一个例子。它正在做以下事情: 1.从上次训练中加载训练过的重量 2.使用相同的测试列车数据(否则您接受了测试数据培训) 3.启动网络的拟合方法(net_loaded.fit(参数)),该方法使用模型的加载权重

要从这个级联中获取图形,您必须将精度值保存在时间图上,或者用于显示组合结果的用途。