在客户端超时时终止服务器处理

时间:2021-03-03 06:35:06

标签: http go tcp timeout

我想知道是否有任何方法可以让 Go HTTP 服务器知道客户端超时,并立即终止正在进行的请求的处理。目前,我已经尝试在客户端设置超时,这些超时实际上在他们这边按预期工作,并且在达到超时后请求以 context deadline exceeded (Client.Timeout exceeded while awaiting headers) 结束。

    req, err := http.NewRequest(http.MethodGet, URL, nil)
    if err != nil {
        log.Fatal(err)
    }
    client := http.Client{Timeout: time.Second}
    _, err = client.Do(req)
    if err != nil {
        log.Fatal(err)
    }

我还尝试了不同版本的客户端代码,例如使用带有上下文的请求,并得到了相同的结果,这对客户端来说是可以的。

然而,当涉及到检测服务器端的超时时,结果发现请求的处理一直持续到服务器完成其工作,而不管客户端的超时时间,以及我希望发生的事情(我不知道是否有可能)是在客户端超时后立即终止并中止处理。

服务器端代码将是这样的(只是为了示例,在生产代码中它会更复杂):

func handler(w http.ResponseWriter, r *http.Request) {
    fmt.Println("before sleep")
    time.Sleep(3 * time.Second)
    fmt.Println("after sleep")

    fmt.Fprintf(w, "Done!")
}

func main() {
    http.HandleFunc("/", handler)
    log.Fatal(http.ListenAndServe(":8080", nil))
}

当前面的代码运行时,一个请求命中了 HTTP 服务器,会发生以下事件序列:

  1. 服务器打印 before sleep
  2. 服务器睡着了
  3. 客户端超时并终止,错误 context deadline exceeded (Client.Timeout exceeded while awaiting headers)
  4. 服务器唤醒并打印 after sleep

但我希望发生的是在第 3 步终止进程。

谢谢,我想知道您对此的看法,以及您认为我想做的事情是否可行。

1 个答案:

答案 0 :(得分:0)

这里有一些不同的想法。首先,为了确认您的要求,您似乎想让客户端断开连接触发整个服务器关闭。为此,您可以执行以下操作:

  1. 添加 context.WithCancelchannel 以用于传播关闭事件
  2. 注意 http 处理程序中的断开连接并取消上下文
  3. 添加一个在通道关闭时关闭服务器的 goroutine

这是一个完整的示例程序,它产生以下输出:

go run ./main.go
2021/03/04 17:56:44 client: starting request
2021/03/04 17:56:44 server: handler started
2021/03/04 17:56:45 client: deadline exceeded
2021/03/04 17:56:45 server: client request canceled
2021/03/04 17:56:45 server: performing server shutdown
2021/03/04 17:56:45 waiting for goroutines to finish
2021/03/04 17:56:45 All exited!
// main.go

package main

import (
    "context"
    "errors"
    "fmt"
    "io/ioutil"
    "log"
    "net/http"
    "os"
    "sync"
    "time"
)

func main() {
    wg := &sync.WaitGroup{}
    srvContext, srvCancel := context.WithCancel(context.Background())
    defer srvCancel()

    srv := http.Server{
        Addr: ":8000",
        Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            log.Printf("server: handler started")
            select {
            case <-time.After(2 * time.Second):
                log.Printf("server: completed long request")
                w.WriteHeader(http.StatusOK)
                w.Write([]byte("OK"))
            case <-r.Context().Done():
                log.Printf("server: client request canceled")
                srvCancel()
                return
            }
        }),
    }

    // add a goroutine that watches for the server context to be canceled
    // as a signal that it is time to stop the HTTP server.
    wg.Add(1)
    go func() {
        defer wg.Done()
        <-srvContext.Done()
        log.Printf("server: performing server shutdown")
        // optionally add a deadline context to avoid waiting too long
        if err := srv.Shutdown(context.TODO()); err != nil {
            log.Printf("server: shutdown failed with context")
        }
    }()

    // just simulate making the request after a brief delay
    wg.Add(1)
    go makeClientRequest(wg)

    if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
        fmt.Fprintf(os.Stderr, "Server failed listening with error: %v\n", err)
        return
    }

    log.Printf("waiting for goroutines to finish")
    wg.Wait()
    log.Printf("All exited!")
}

func makeClientRequest(wg *sync.WaitGroup) {
    defer wg.Done()
    // delay client request
    time.Sleep(500 * time.Millisecond)
    log.Printf("client: starting request")

    ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
    defer cancel()

    req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1:8000", http.NoBody)
    if err != nil {
        log.Fatalf("failed making client request")
    }
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        if errors.Is(err, context.DeadlineExceeded) {
            log.Printf("client: deadline exceeded")
        } else {
            log.Printf("client: request error: %v", err)
        }
        return
    }

    // got a non-error response
    defer resp.Body.Close()
    body, _ := ioutil.ReadAll(resp.Body)
    log.Printf("client: got response %d %s", resp.StatusCode, string(body))
}