我正在构建seq2seq模型,并使用检查点保存编码器和解码器:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
encoder=encoder,
decoder=decoder)
checkpoint.save(file_prefix = checkpoint_prefix)
但是,当我尝试从检查点目录还原时,它似乎不起作用。
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
encoder=encoder,
decoder=decoder)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
该如何解决?谢谢〜