Tensorflow:成功恢复检查点后丢失重置

时间:2016-12-26 09:03:13

标签: tensorflow reset restore checkpoint loss

保存或恢复时没有错误。权重似乎已正确恢复。

我正在尝试按照karpathy/min-char-rnn.pysherjilozair/char-rnn-tensorflowTensorflow RNN tutorial建立我自己的最小字符级RNN。我的脚本似乎按预期工作,除非我尝试恢复/恢复培训。

如果我重新启动脚本并从检查点恢复然后恢复训练,则丢失总是会恢复,就像没有检查点一样(尽管权重已正确恢复)。 但是,在脚本的执行过程中,如果我重置图表,启动新会话并恢复,那么我可以按预期继续最小化损失。

我试图在桌面(使用GPU)和笔记本电脑(仅限CPU)上运行此操作,两者都在Windows上使用Tensorflow 0.12。

下面是我的代码,我在这里上传了代码+数据+控制台输出: https://gist.github.com/dk1027/777c3da7ba1ff7739b5f5e89491bef73

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import rnn_cell

class model_input:

    def __init__(self,data_path, batch_size, steps):
        self.batch_idx = 0
        self.data_path = data_path
        self.steps = steps
        self.batch_size = batch_size
        data = open(self.data_path).read()
        data_size = len(data)
        self.vocab = set(data)
        self.vocab_size = len(self.vocab)
        self.vocab_to_idx = {v:i for i,v in enumerate(self.vocab)}
        self.idx_to_vocab = {i:v for i,v in enumerate(self.vocab)}
        c = self.batch_size * self.steps
        #Offset by 1 character because we want to predict the next character
        _data_as_idx = np.asarray([self.vocab_to_idx[v] for v in data], dtype=np.int32)
        self.X = _data_as_idx[:-1]
        self.Y = _data_as_idx[1:]

    def reset(self):
        self.batch_idx = 0

    def next_batch2(self):
        i = self.batch_idx
        j = self.batch_idx + self.batch_size * self.steps

        if j >= self.X.shape[0]:
            i = 0
            j = self.batch_size * self.steps
            self.batch_idx = 0

        #print("next_batch: (%s,%s)" %(i,j))
        x = self.X[i:j]
        x = x.reshape(-1,self.steps)

        _xlen = x.shape[0]
        _y = self.Y[i:j]
        _y = _y.reshape(-1,self.steps)
        self.batch_idx += 1

        return x, _y

    def toIdx(self, s):
        res = []
        for _s in s:
            res.append(self.vocab_to_idx[_s])
        return res

    def toStr(self, idx):
        s = ''
        for i in idx:
            s += self.idx_to_vocab[i]
        return s

class Config():
    def __init__(self):
        # Parameters
        self.learning_rate = 0.001
        self.training_iters = 10000
        self.batch_size = 20
        self.display_step = 200
        self.max_epoch = 1
        # Network Parameters
        self.n_input = 1 # 1 character input
        self.n_steps = 25 # sequence length
        self.n_hidden = 128 # hidden layer num of features
        self.n_rnn_layers = 2
        # To be set later
        self.vocab_size = None

# Train
def Train(sess, model, data, config, saver):
    init_state = sess.run(model.initial_state)
    data.reset()
    epoch = 0
    while epoch < config.max_epoch:
        # Keep training until reach max iterations
        step = 0
        while step * config.batch_size < config.training_iters:
            # Run optimization op (backprop)
            fetch_dict = {
                "cost": model.cost,
                "final_state": model.final_state,
                "op" : model.train_op
            }
            feed_dict = {}
            for i, (c, h) in enumerate(model.initial_state):
                feed_dict[c] = init_state[i].c
                feed_dict[h] = init_state[i].h
            batch_x, batch_y = data.next_batch2()
            feed_dict[model.x]=batch_x
            feed_dict[model.y]=batch_y
            fetches = sess.run(fetch_dict, feed_dict=feed_dict)

            if (step % config.display_step) == 0:
                print("Iter " + str(step*config.batch_size) + ", Minibatch Loss={:.7f}".format(fetches["cost"]))
            step += 1
            if (step*config.batch_size % 5000) == 0:
                sp = saver.save(sess, config.save_path + "model.ckpt", global_step = step * config.batch_size + epoch * config.training_iters)
                print("Saved to %s" % sp)
        sp = saver.save(sess, config.save_path + "model.ckpt", global_step = step * config.batch_size + epoch * config.training_iters)
        print("Saved to %s" % sp)
        epoch += 1

    print("Optimization Finished!")


class Model():
    def __init__(self, config):
        self.config = config

        lstm_cell = rnn_cell.BasicLSTMCell(config.n_hidden, state_is_tuple=True)

        self.cell = rnn_cell.MultiRNNCell([lstm_cell] * config.n_rnn_layers, state_is_tuple=True)

        self.x = tf.placeholder(tf.int32, [config.batch_size, config.n_steps])
        self.y = tf.placeholder(tf.int32, [config.batch_size, config.n_steps]) 
        self.initial_state = self.cell.zero_state(config.batch_size, tf.float32)

        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [config.vocab_size, config.n_hidden], dtype=tf.float32)
            inputs = tf.nn.embedding_lookup(embedding, self.x)
        outputs = []
        state = self.initial_state
        with tf.variable_scope('rnn'):
            softmax_w = tf.get_variable("softmax_w", [config.n_hidden, config.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [config.vocab_size])

            for time_step in range(config.n_steps):
                if time_step > 0: tf.get_variable_scope().reuse_variables()
                (cell_output, state) = self.cell(inputs[:, time_step, :], state)
                outputs.append(cell_output)

        output = tf.reshape(tf.concat(1, outputs), [-1, config.n_hidden])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        loss = tf.nn.seq2seq.sequence_loss_by_example(
            [self.logits],
            [self.y],
            [tf.ones([config.batch_size * config.n_steps], dtype=tf.float32)],
            name="seq2seq")

        self.cost = tf.reduce_sum(loss) / config.batch_size
        self.final_state = state

        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),5)
        optimizer = tf.train.AdamOptimizer(config.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

def main():
    # Read input data
    data_path = "1sonnet.txt"
    save_path = "./save/"

    config = Config()
    data = model_input(data_path, config.batch_size, config.n_steps)
    config.vocab_size = data.vocab_size
    config.data_path = data_path
    config.save_path = save_path

    train_model = Model(config)
    print("Model defined.")

    bReproProblem = True
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(save_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("restored from %s" % ckpt.model_checkpoint_path)

        Train(sess, train_model, data, config, saver)


    if bReproProblem:
        tf.reset_default_graph() #reset everything
        data.reset()
        train_model2 = Model(config)
        print("Starting a new session, restore from checkpoint, and train again")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver2 = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(save_path)
            if ckpt and ckpt.model_checkpoint_path:
                saver2.restore(sess, ckpt.model_checkpoint_path)
                print("restored from %s" % ckpt.model_checkpoint_path)

            Train(sess, train_model2, data, config, saver2)


if __name__ == '__main__':
    main()

1 个答案:

答案 0 :(得分:0)

TL; DR

请确保每次运行代码时标签都相同,尤其是对于那些使用列表索引作为标签的用户。

有关详细信息,请参见this question

如果将列表索引用作标签,请对数据进行排序或将索引保存到磁盘。使用:

labels = sorted(set(data))

代替

labels = set(data))

一般建议

在Python实现中,有些方法(例如set()os.listdir())会返回未排序的集合。换句话说,每次运行项目的索引可能会有所不同。

对于set(),Python use a random method构建set。对于os.listdir()it doesn't promise the order of the returned list。因此,对于健壮的代码,建议对数据集使用sorted()

您的问题

data_size = len(data)
self.vocab = set(data)
self.vocab_size = len(self.vocab)
self.vocab_to_idx = {v:i for i,v in enumerate(self.vocab)}
self.idx_to_vocab = {i:v for i,v in enumerate(self.vocab)}

这可能是由您构建标签的方式引起的。每次您运行代码时,vocab_to_idx可能会有所不同。

只需添加一个sorted()

data_size = len(data)
self.vocab = sorted(set(data))
self.vocab_size = len(self.vocab)
self.vocab_to_idx = {v:i for i,v in enumerate(self.vocab)}
self.idx_to_vocab = {i:v for i,v in enumerate(self.vocab)}