mirror of
https://github.com/thecodingrobot/echoip.git
synced 2026-05-15 11:07:07 +02:00
feat: graceful shutdown on SIGINT/SIGTERM
main now uses signal.NotifyContext to cancel a root context on SIGINT or SIGTERM. Both the public and debug listeners share that context; on cancellation, each calls srv.Shutdown with a 30s timeout to drain in-flight requests before exiting. If either listener fails, the other is signalled to shut down too, and the process exits 1. Bad CLI args now exit 2 (was 0). Adds TestGracefulShutdown, which fires a slow request, cancels the context mid-flight, and asserts that the request completes successfully and the server returns nil.
This commit is contained in:
parent
fe098d6357
commit
abbdb05709
3 changed files with 170 additions and 13 deletions
|
|
@ -1,11 +1,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/mpolden/echoip/http"
|
||||
"github.com/mpolden/echoip/iputil"
|
||||
|
|
@ -44,7 +47,7 @@ func main() {
|
|||
flag.Parse()
|
||||
if len(flag.Args()) != 0 {
|
||||
flag.Usage()
|
||||
return
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
r, err := geo.Open(*countryFile, *cityFile, *asnFile)
|
||||
|
|
@ -82,16 +85,49 @@ func main() {
|
|||
if *cacheSize > 0 {
|
||||
log.Printf("Cache capacity set to %d", *cacheSize)
|
||||
}
|
||||
|
||||
// signal.NotifyContext cancels ctx on SIGINT/SIGTERM, triggering
|
||||
// graceful shutdown of all running listeners.
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
if *profileAddr != "" {
|
||||
log.Printf("Enabling debug/profiling handlers on http://%s (do not expose publicly)", *profileAddr)
|
||||
wg.Add(1)
|
||||
go func(addr string) {
|
||||
if err := server.ListenAndServeDebug(addr); err != nil {
|
||||
log.Fatalf("debug listener on %s failed: %s", addr, err)
|
||||
defer wg.Done()
|
||||
if err := server.ListenAndServeDebug(ctx, addr); err != nil {
|
||||
errCh <- err
|
||||
stop() // trigger shutdown of the public listener too
|
||||
}
|
||||
}(*profileAddr)
|
||||
}
|
||||
log.Printf("Listening on http://%s", *listen)
|
||||
if err := server.ListenAndServe(*listen); err != nil {
|
||||
log.Fatal(err)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
log.Printf("Listening on http://%s", *listen)
|
||||
if err := server.ListenAndServe(ctx, *listen); err != nil {
|
||||
errCh <- err
|
||||
stop()
|
||||
}
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
log.Println("Shutdown signal received, draining in-flight requests...")
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
var failed bool
|
||||
for err := range errCh {
|
||||
log.Printf("listener error: %s", err)
|
||||
failed = true
|
||||
}
|
||||
if failed {
|
||||
os.Exit(1)
|
||||
}
|
||||
log.Println("Shutdown complete")
|
||||
}
|
||||
|
|
|
|||
46
http/http.go
46
http/http.go
|
|
@ -1,6 +1,7 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
|
|
@ -496,15 +497,50 @@ func newServer(addr string, h http.Handler) *http.Server {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) ListenAndServe(addr string) error {
|
||||
return newServer(addr, s.Handler()).ListenAndServe()
|
||||
// shutdownTimeout is how long an HTTP server is given to drain in-flight
// requests once shutdown has been initiated.
const shutdownTimeout = 30 * time.Second

// listenAndServe runs srv until it fails or ctx is cancelled.
//
// On cancellation it calls srv.Shutdown with a shutdownTimeout deadline so
// that in-flight requests can complete. It returns:
//   - nil after a clean shutdown,
//   - the error from srv.ListenAndServe if serving failed, or
//   - a wrapped error if draining did not finish before the deadline.
func listenAndServe(ctx context.Context, srv *http.Server) error {
	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if this function returns without reading it.
	errCh := make(chan error, 1)
	go func() {
		err := srv.ListenAndServe()
		if err == http.ErrServerClosed {
			// Expected outcome of Shutdown/Close; not a failure.
			err = nil
		}
		errCh <- err
	}()

	select {
	case err := <-errCh:
		return err
	case <-ctx.Done():
	}

	shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		// Draining ran out of time. Force-close whatever is still open so
		// connections do not linger; the Shutdown error is the one worth
		// reporting, so the Close error is intentionally dropped.
		_ = srv.Close()
		return fmt.Errorf("graceful shutdown failed: %w", err)
	}

	// Report the serving goroutine's result instead of discarding it: if the
	// listener failed at the same moment ctx was cancelled, that error must
	// not be masked as a clean shutdown. After a normal Shutdown this is nil.
	return <-errCh
}
|
||||
|
||||
// ListenAndServe starts the public HTTP server on addr. It blocks until ctx
|
||||
// is cancelled or the server fails. On cancellation, in-flight requests are
|
||||
// given shutdownTimeout to complete.
|
||||
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
|
||||
return listenAndServe(ctx, newServer(addr, s.Handler()))
|
||||
}
|
||||
|
||||
// ListenAndServeDebug starts an HTTP server bound to addr that serves the
|
||||
// debug handler. The caller is responsible for ensuring addr is not reachable
|
||||
// from untrusted networks.
|
||||
func (s *Server) ListenAndServeDebug(addr string) error {
|
||||
return newServer(addr, s.DebugHandler()).ListenAndServe()
|
||||
// from untrusted networks. It blocks until ctx is cancelled or the server
|
||||
// fails.
|
||||
func (s *Server) ListenAndServeDebug(ctx context.Context, addr string) error {
|
||||
return listenAndServe(ctx, newServer(addr, s.DebugHandler()))
|
||||
}
|
||||
|
||||
func formatCoordinate(c float64) string {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,17 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mpolden/echoip/iputil/geo"
|
||||
)
|
||||
|
|
@ -282,3 +285,85 @@ func TestCLIMatcher(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGracefulShutdown(t *testing.T) {
|
||||
// Verify ListenAndServe drains an in-flight request when ctx is cancelled.
|
||||
srv := testServer()
|
||||
srv.Handler() // ensure routes registered
|
||||
|
||||
// Bind to an ephemeral port so the test is hermetic.
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
ln.Close()
|
||||
|
||||
// Inject a slow handler via a custom mux wrapping the real one.
|
||||
slowReady := make(chan struct{})
|
||||
slowDone := make(chan struct{})
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) {
|
||||
close(slowReady)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
w.Write([]byte("ok"))
|
||||
close(slowDone)
|
||||
})
|
||||
|
||||
httpSrv := newServer(addr, mux)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
serveErr := make(chan error, 1)
|
||||
go func() { serveErr <- listenAndServe(ctx, httpSrv) }()
|
||||
|
||||
// Wait for the listener to be up by retrying a short connection.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||
if err == nil {
|
||||
c.Close()
|
||||
break
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("server never came up: %v", err)
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Fire the slow request, wait until handler started, then cancel.
|
||||
respCh := make(chan *http.Response, 1)
|
||||
respErr := make(chan error, 1)
|
||||
go func() {
|
||||
resp, err := http.Get("http://" + addr + "/slow")
|
||||
if err != nil {
|
||||
respErr <- err
|
||||
return
|
||||
}
|
||||
respCh <- resp
|
||||
}()
|
||||
<-slowReady
|
||||
cancel()
|
||||
|
||||
// Shutdown must wait for the in-flight request to finish.
|
||||
select {
|
||||
case resp := <-respCh:
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if string(body) != "ok" {
|
||||
t.Fatalf("in-flight request lost: got %q", body)
|
||||
}
|
||||
case err := <-respErr:
|
||||
t.Fatalf("in-flight request failed during shutdown: %v", err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("in-flight request did not complete within 2s")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-slowDone:
|
||||
default:
|
||||
t.Fatal("slow handler did not finish")
|
||||
}
|
||||
|
||||
if err := <-serveErr; err != nil {
|
||||
t.Fatalf("listenAndServe returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue