Tensorflow - 保存和恢复模型

时间:2017-03-27 00:04:31

标签: python machine-learning tensorflow neural-network conv-neural-network

我在Stackoverflow中遇到了this question,它显示了如何保存和恢复模型。

我的问题是如何在我的代码中执行此操作,因为我不确定如何将其与我的代码集成:

  #user_controller.rb
  def update
    UserManagement::UserUpdatePassword.new(@user, user_params).call
  end

  def user_params
    params.require(:user).permit(:password, :password_confirmation)
  end

  def get_user
    @user ||= User.find_by_password_reset_token(params[:id])
  end

感谢。

1 个答案:

答案 0 :(得分:0)

以下是我过去用于恢复的一些示例代码。这应该在会话创建之后但在运行模型之前完成。

saver = tf.train.Saver()

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(ckpt.model_checkpoint_path)
    i_stopped = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])
else:
    print('No checkpoint file found!')
    i_stopped = 0

对于保存,每1000批次,或者在您的情况下,您可以保存每个时期:

if i % 1000 == 0:
    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=i)

在代码中实现它应该相当简单。请记住,您必须定义将保存模型的检查点目录。

希望这有帮助!

相关问题