feat: graceful shutdown on SIGINT/SIGTERM

main now uses signal.NotifyContext to cancel a root context on
SIGINT or SIGTERM. Both the public and debug listeners share that
context; on cancellation each calls srv.Shutdown with a 30s
timeout to drain in-flight requests before exiting.

If either listener fails, the other is signalled to shut down too,
and the process exits 1. Bad CLI args now exit 2 (was 0).

Adds TestGracefulShutdown which fires a slow request, cancels the
context mid-flight, and asserts the request completes successfully
and the server returns nil.
This commit is contained in:
ns 2026-05-07 07:21:07 +00:00
commit abbdb05709
No known key found for this signature in database
GPG key ID: 69784C31D818C1A1
3 changed files with 170 additions and 13 deletions

View file

@ -1,11 +1,14 @@
package main
import (
"context"
"flag"
"log"
"strings"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"github.com/mpolden/echoip/http"
"github.com/mpolden/echoip/iputil"
@ -44,7 +47,7 @@ func main() {
flag.Parse()
if len(flag.Args()) != 0 {
flag.Usage()
return
os.Exit(2)
}
r, err := geo.Open(*countryFile, *cityFile, *asnFile)
@ -82,16 +85,49 @@ func main() {
if *cacheSize > 0 {
log.Printf("Cache capacity set to %d", *cacheSize)
}
// signal.NotifyContext cancels ctx on SIGINT/SIGTERM, triggering
// graceful shutdown of all running listeners.
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
var wg sync.WaitGroup
errCh := make(chan error, 2)
if *profileAddr != "" {
log.Printf("Enabling debug/profiling handlers on http://%s (do not expose publicly)", *profileAddr)
wg.Add(1)
go func(addr string) {
if err := server.ListenAndServeDebug(addr); err != nil {
log.Fatalf("debug listener on %s failed: %s", addr, err)
defer wg.Done()
if err := server.ListenAndServeDebug(ctx, addr); err != nil {
errCh <- err
stop() // trigger shutdown of the public listener too
}
}(*profileAddr)
}
log.Printf("Listening on http://%s", *listen)
if err := server.ListenAndServe(*listen); err != nil {
log.Fatal(err)
wg.Add(1)
go func() {
defer wg.Done()
log.Printf("Listening on http://%s", *listen)
if err := server.ListenAndServe(ctx, *listen); err != nil {
errCh <- err
stop()
}
}()
<-ctx.Done()
log.Println("Shutdown signal received, draining in-flight requests...")
wg.Wait()
close(errCh)
var failed bool
for err := range errCh {
log.Printf("listener error: %s", err)
failed = true
}
if failed {
os.Exit(1)
}
log.Println("Shutdown complete")
}

View file

@ -1,6 +1,7 @@
package http
import (
"context"
"encoding/json"
"fmt"
"html/template"
@ -496,15 +497,50 @@ func newServer(addr string, h http.Handler) *http.Server {
}
}
func (s *Server) ListenAndServe(addr string) error {
return newServer(addr, s.Handler()).ListenAndServe()
// shutdownTimeout is how long an HTTP server is given to drain in-flight
// requests once shutdown has been initiated.
const shutdownTimeout = 30 * time.Second

// listenAndServe runs srv until it stops on its own or ctx is cancelled.
//
// If srv.ListenAndServe returns before ctx is cancelled, its error is
// returned, with http.ErrServerClosed mapped to nil.
//
// When ctx is cancelled, srv.Shutdown is called with a deadline of
// shutdownTimeout so in-flight requests can drain. If draining succeeds, the
// result of the serving goroutine is returned: nil on a clean shutdown, or the
// listen/serve error if the server had failed concurrently with the
// cancellation. If draining does not finish in time, the remaining
// connections are force-closed and the Shutdown error is returned wrapped.
func listenAndServe(ctx context.Context, srv *http.Server) error {
	// Buffered so the serving goroutine can always deliver its result and
	// exit, even on paths where nobody receives it.
	errCh := make(chan error, 1)
	go func() {
		err := srv.ListenAndServe()
		if err == http.ErrServerClosed {
			// Expected result of Shutdown/Close; not a failure.
			err = nil
		}
		errCh <- err
	}()

	select {
	case err := <-errCh:
		// The server stopped by itself (typically a bind failure).
		return err
	case <-ctx.Done():
	}

	// ctx is already cancelled, so the drain deadline must hang off a fresh
	// context rather than ctx.
	shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
	defer cancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		// The drain window elapsed (or Shutdown failed) with connections
		// still open. Force-close them so they do not outlive this call.
		// The Close error is dropped on purpose: the Shutdown error below is
		// the one worth reporting.
		_ = srv.Close()
		return fmt.Errorf("graceful shutdown failed: %w", err)
	}
	// Surface a genuine serve error that raced with the cancellation instead
	// of reporting it as a clean shutdown.
	return <-errCh
}
// ListenAndServe serves the public handler on addr and blocks until the
// server fails or ctx is cancelled. Once ctx is cancelled, requests that are
// still in flight get up to shutdownTimeout to finish.
func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
	public := newServer(addr, s.Handler())
	return listenAndServe(ctx, public)
}
// ListenAndServeDebug starts an HTTP server bound to addr that serves the
// debug handler. The caller is responsible for ensuring addr is not reachable
// from untrusted networks.
func (s *Server) ListenAndServeDebug(addr string) error {
return newServer(addr, s.DebugHandler()).ListenAndServe()
// from untrusted networks. It blocks until ctx is cancelled or the server
// fails.
func (s *Server) ListenAndServeDebug(ctx context.Context, addr string) error {
	// Same lifecycle as the public listener, but serving the debug handler.
	debug := newServer(addr, s.DebugHandler())
	return listenAndServe(ctx, debug)
}
func formatCoordinate(c float64) string {

View file

@ -1,14 +1,17 @@
package http
import (
"context"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"strings"
"testing"
"time"
"github.com/mpolden/echoip/iputil/geo"
)
@ -282,3 +285,85 @@ func TestCLIMatcher(t *testing.T) {
}
}
}
func TestGracefulShutdown(t *testing.T) {
// Verify ListenAndServe drains an in-flight request when ctx is cancelled.
srv := testServer()
srv.Handler() // ensure routes registered
// Bind to an ephemeral port so the test is hermetic.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
addr := ln.Addr().String()
ln.Close()
// Inject a slow handler via a custom mux wrapping the real one.
slowReady := make(chan struct{})
slowDone := make(chan struct{})
mux := http.NewServeMux()
mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) {
close(slowReady)
time.Sleep(200 * time.Millisecond)
w.Write([]byte("ok"))
close(slowDone)
})
httpSrv := newServer(addr, mux)
ctx, cancel := context.WithCancel(context.Background())
serveErr := make(chan error, 1)
go func() { serveErr <- listenAndServe(ctx, httpSrv) }()
// Wait for the listener to be up by retrying a short connection.
deadline := time.Now().Add(2 * time.Second)
for {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
c.Close()
break
}
if time.Now().After(deadline) {
t.Fatalf("server never came up: %v", err)
}
time.Sleep(10 * time.Millisecond)
}
// Fire the slow request, wait until handler started, then cancel.
respCh := make(chan *http.Response, 1)
respErr := make(chan error, 1)
go func() {
resp, err := http.Get("http://" + addr + "/slow")
if err != nil {
respErr <- err
return
}
respCh <- resp
}()
<-slowReady
cancel()
// Shutdown must wait for the in-flight request to finish.
select {
case resp := <-respCh:
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
if string(body) != "ok" {
t.Fatalf("in-flight request lost: got %q", body)
}
case err := <-respErr:
t.Fatalf("in-flight request failed during shutdown: %v", err)
case <-time.After(2 * time.Second):
t.Fatal("in-flight request did not complete within 2s")
}
select {
case <-slowDone:
default:
t.Fatal("slow handler did not finish")
}
if err := <-serveErr; err != nil {
t.Fatalf("listenAndServe returned error: %v", err)
}
}