在训练期间读取数据子集时如何保存/恢复?

时间:2019-04-10 11:39:52

标签: python tensorflow

我想要一种有效的方法来保存训练过程(假定已读取一部分数据),然后加载新的数据子集并从保存的同一点继续进行训练。如何正确保存/加载模型来做到这一点?

我已经尝试了许多发现/保存模型的方法,但是由于所有更改都取决于版本,我感到困惑。

def train(self,sess,iter):
    start_time = time.clock()

    n_early_stop_epochs = 10  # Define it
    n_epochs = 30  # Define it

    # restore variables from previous train session
    if(iter>0): restore_variables(sess)

    # create saver object
    saver = tf.train.Saver(var_list = tf.trainable_variables(), max_to_keep = 4)

    early_stop_counter = 0

    # initialize train variables
    init_op = tf.group(tf.global_variables_initializer())

    sess.run(init_op)

    # assign a large value to min
    min_valid_loss = sys.float_info.max
    epoch=0

    # loop for a given number of epochs
    while (epoch < n_epochs): # max num epoch iteration
        epoch += 1
        epoch_start_time = time.clock()

        train_loss = self.train_epoch(sess)
        valid_loss = self.valid_epoch(sess)
        # print("valid ends")
        epoch_end_time=time.clock()
        if (epoch % 10 == 0):
            info_str ='Epoch='+str(epoch) + ', Train: ' + str(train_loss) + ', Valid: '
            info_str += str(valid_loss) + ', Time=' +str(epoch_end_time - epoch_start_time)
            print(info_str)

        if valid_loss < min_valid_loss:
            print('Best epoch=' + str(epoch))
            save_variables(sess, saver, epoch, self.model_id)
            min_valid_loss = valid_loss
            early_stop_counter = 0
        else:
            early_stop_counter += 1


        # stop training when overfiiting conditon is true
        if early_stop_counter > n_early_stop_epochs:
            # too many consecutive epochs without surpassing the best model
            print('stopping early')
            # self.kill = True
            break
    end_time=time.clock()
    print('Total time = ' + str(end_time - start_time))

我有一个for循环,用于将数据从磁盘读取到内存(读取一个子集),对这个子集数据进行训练, 我保留一个变量

  iter = 0 # counter, a number for current loop

要知道在除第一个循环外的所有其他循环中,我们首先加载模型。 一切似乎都正常,但是我不确定训练变量会从最后一个检查点继续更新还是从零开始重新开始。因为我的训练损失在几个时期内几乎为零,而验证损失并未从一个子集收敛到另一个子集,毕竟读取了数据,因此每个数据子集的验证损失可能取0.3到7的值。

例如数据 training_loss = 0.0004 第一个循环-拳头子集 validation_loss = 0.34 的多个时期, 在数据的下一次迭代中 training_loss = 0.0000 validation_loss = 5.2 等。 Validation_loss有高有低,不会收敛

0 个答案:

没有答案