Tensorflow: checkpoint is only saved on the first epoch

Date: 2019-07-12 20:27:01

Tags: python-3.x tensorflow

I am training an NMT model and saving checkpoints with the code snippet below. However, a checkpoint is only saved for the first epoch, not for the remaining ones. Strangely, it works as expected on small datasets (around 10 to 50 rows), saving at every scheduled epoch; when I tested with 50,000 rows it no longer saved. I'm not sure what I'm missing here.

checkpoint_dir = './training_checkpoints_testingsave'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)


EPOCHS = 10

for epoch in range(EPOCHS):
    start = time.time()

    hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset):
        loss = 0

        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(inp, hidden)

            dec_hidden = enc_hidden

            dec_input = tf.expand_dims([targ_lang.word2idx['<start>']] * BATCH_SIZE, 1)       

            # Teacher forcing - feeding the target as the next input
            for t in range(1, targ.shape[1]):
                # passing enc_output to the decoder
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)

                loss += loss_function(targ[:, t], predictions)

                # using teacher forcing
                dec_input = tf.expand_dims(targ[:, t], 1)

        batch_loss = (loss / int(targ.shape[1]))

        total_loss += batch_loss

        variables = encoder.variables + decoder.variables

        gradients = tape.gradient(loss, variables)

        optimizer.apply_gradients(zip(gradients, variables))

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,
                                                         batch,
                                                         batch_loss.numpy()))

    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

    print('Epoch {} Loss {:.4f}'.format(epoch + 1,
                                        total_loss / N_BATCH))
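To narrow the problem down, it can help to test the checkpointing logic on its own, separate from the model and the dataset. The following is a minimal, self-contained sketch (a dummy `tf.Variable` stands in for the encoder/decoder; the `step` name and 4-epoch count are made up for illustration, not from the question) that reuses the same every-2-epochs schedule:

```python
import os
import tempfile
import tensorflow as tf

# Stand-in state: one counter variable instead of the full NMT model.
ckpt_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(ckpt_dir, "ckpt")

step = tf.Variable(0)
checkpoint = tf.train.Checkpoint(step=step)

EPOCHS = 4
for epoch in range(EPOCHS):
    step.assign_add(1)  # pretend training work
    # same every-2-epochs schedule as in the question's loop
    if (epoch + 1) % 2 == 0:
        save_path = checkpoint.save(file_prefix=ckpt_prefix)
        print('Epoch {} saved: {}'.format(epoch + 1, save_path))
```

Each call to `checkpoint.save` increments an internal save counter and writes a new numbered checkpoint (`ckpt-1`, `ckpt-2`, ...), so with `EPOCHS = 4` two checkpoints should appear. If this sketch saves on every scheduled epoch but the real loop does not, the issue is likely in how the training loop reaches (or fails to reach) the `save` call, not in `tf.train.Checkpoint` itself.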

0 Answers:

There are no answers.