Accuracy of restored model does not improve in TensorFlow

Date: 2017-12-07 08:47:34

Tags: python tensorflow

I am running an LSTM network for language modelling and I use tf.train.Supervisor to save and restore the session.

After each epoch I print the perplexity. The first time I restored the session I noticed that the perplexity was lower than in the previous run, but from then on, although the perplexity keeps dropping over further epochs within a run, whenever I restore the model the perplexity is roughly the same as it was right after I first saved it. I always print global_step, so I am sure I am loading the latest checkpoint.
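For completeness, here is a minimal sketch of how the checkpoint contents can be inspected directly; it assumes SAVE_PATH points at the Supervisor's logdir (the placeholder path below is hypothetical) and that the step variable is stored under the name 'global_step':

import tensorflow as tf

SAVE_PATH = './lstm_logdir'  # hypothetical path; in my code this is the Supervisor's logdir

# Path of the most recent checkpoint written into the logdir.
ckpt_path = tf.train.latest_checkpoint(SAVE_PATH)
print('Latest checkpoint:', ckpt_path)

# Read the saved global_step straight from the checkpoint file
# (assuming the step variable is stored under the name 'global_step').
reader = tf.train.NewCheckpointReader(ckpt_path)
print('Saved global_step:', reader.get_tensor('global_step'))

# List every variable in the checkpoint to confirm the model weights are saved.
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)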

Perplexity values from the first run:

Global step 0
Epoch: 1
Train perplexity: 1053.873
Eval perplexity: 994.486
Epoch: 2
Train perplexity: 559.507
Eval perplexity: 803.345
Epoch: 3
Train perplexity: 377.886
Eval perplexity: 606.682
Epoch: 4
Train perplexity: 282.728
Eval perplexity: 472.485
Epoch: 5
Train perplexity: 229.564
Eval perplexity: 433.604

Second run, after restoring:

Global step 830
Epoch: 1
Train perplexity: 394.555
Eval perplexity: 562.316
Epoch: 2
Train perplexity: 280.981
Eval perplexity: 440.451
Epoch: 3
Train perplexity: 226.292
Eval perplexity: 384.905
Epoch: 4
Train perplexity: 189.826
Eval perplexity: 340.012
Epoch: 5
Train perplexity: 166.766
Eval perplexity: 328.017

Third run, after restoring:

Global step 1648
Epoch: 1
Train perplexity: 374.898
Eval perplexity: 508.347
Epoch: 2
Train perplexity: 271.804
Eval perplexity: 419.742
Epoch: 3
Train perplexity: 224.735
Eval perplexity: 367.012
Epoch: 4
Train perplexity: 192.667
Eval perplexity: 336.119
Epoch: 5
Train perplexity: 170.210
Eval perplexity: 303.626

All further runs give similar results: global_step keeps increasing, but the perplexity values stay about the same.

Below is the code snippet where I start the session.

import time

import numpy as np
import tensorflow as tf


def run_epoch(session: tf.Session, model, is_train=False, verbose=False):
    """Runs the model over its whole input once and returns the perplexity."""
    costs = 0
    costs_list = []
    iters = 0
    start_time = time.time()
    # Tensors to fetch on every step.
    fetches = {'cost': model.cost,
               'final_state': model.final_state,
               'outputs': model.outputs,
               'states': model.states
               }

    if is_train:
        # Include the training op only when this epoch should update the weights.
        fetches.update({'train_op': model._train_op})

    # Start from the zero state and carry the LSTM state across steps by
    # feeding each step's final state in as the next step's initial state.
    state = session.run(model.initial_state)
    for step in range(model.input.epoch_size):
        feed_dict = {}
        for i, (c, h) in enumerate(model.initial_state):
            feed_dict[c] = state[i].c
            feed_dict[h] = state[i].h

        vals = session.run(fetches, feed_dict)

        cost = vals["cost"]
        state = vals["final_state"]

        costs += cost
        costs_list.append(cost)
        iters += model.input.num_steps

        if verbose and step % (model.input.epoch_size // 10) == 10:
            print('mean_cost:', np.array(costs_list).mean(), 'costs:', costs, 'iters:', iters)
            print("%.3f perplexity: %.3f speed: %.0f wps" %
                  (step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
                  iters * model.input.batch_size * max(1, 1) /
                  (time.time() - start_time)))
            costs_list = []
    # Perplexity is the exponential of the mean per-step cross-entropy cost.
    return np.exp(costs / iters)


def train_model():
    config = ModelConfig()
    raw_data, id_to_word = embed_to_vocab(config, TRAIN_SOURCE)
    word_to_id = {w: i for i, w in id_to_word.items()}
    raw_eval_data, id_to_word_eval = embed_to_vocab(config, EVAL_SOURCE, word_to_id=word_to_id)


    with tf.Graph().as_default():

        with tf.name_scope('train'):
            with tf.variable_scope('Model'):
                data_input = LstmInput(config, raw_data)
                model = LstmNet(config, data_input)

        print('train model created')

        # The eval model reuses (shares) the training model's variables.
        with tf.name_scope('eval'):
            with tf.variable_scope('Model', reuse=True):
                eval_input = LstmInput(config, raw_eval_data)
                eval_model = LstmNet(config, eval_input, is_training=False)

        print('eval model created')

        config_proto = tf.ConfigProto(allow_soft_placement=False)
        # The Supervisor restores the latest checkpoint found in SAVE_PATH (if
        # one exists) and writes a new checkpoint every 7 seconds.
        sv = tf.train.Supervisor(logdir=SAVE_PATH, save_model_secs=7)

        with sv.managed_session(config=config_proto) as session:
            print('Global step', sv.global_step.eval(session=session))
            for i in range(EPOCHS):
                print('Epoch: %s' % (i + 1))
                train_perplexity = run_epoch(session, model, verbose=False, is_train=True)
                print('Train perplexity: {:.3f}'.format(train_perplexity))
                eval_perplexity = run_epoch(session, eval_model)
                print('Eval perplexity: {:.3f}'.format(eval_perplexity))
            print(sv.global_step.eval(session=session))
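
For reference, this is the same save/restore pattern reduced to just the global step; it is only a sketch to illustrate how managed_session picks up the latest checkpoint from its logdir (the logdir path and the explicit final save are assumptions, not part of my model code):

import tensorflow as tf

def supervisor_restore_sketch(logdir='./sv_demo'):  # hypothetical logdir
    with tf.Graph().as_default():
        # A bare global step and an op that advances it stand in for the training op.
        global_step = tf.train.get_or_create_global_step()
        increment = tf.assign_add(global_step, 1)

        # managed_session() restores the newest checkpoint found in logdir, if any.
        sv = tf.train.Supervisor(logdir=logdir, save_model_secs=7)
        with sv.managed_session() as session:
            # On a second run this prints the restored step instead of 0.
            print('Global step at start:', session.run(global_step))
            for _ in range(100):
                session.run(increment)
            # Write a final checkpoint explicitly so short runs get saved too.
            sv.saver.save(session, sv.save_path, global_step=sv.global_step)
            print('Global step at end:', session.run(global_step))

Running this twice in a row should show the second run starting from the step where the first one stopped, which matches the global_step values I see above.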

0 Answers:

There are no answers yet.