os/signal: make NotifyContext cancel the context with a cause

This is especially useful when combined with the nesting semantics of
context.Cause, and with errgroup's use of CancelCauseFunc.

For example, with the following code

	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
	defer stop()
	serveGroup, ctx := errgroup.WithContext(ctx)

calling context.Cause(ctx) after serveGroup.Wait() will return either
"interrupt signal received" (if that happens first) or the error from
serveGroup.

Change-Id: Ie181f5f84269f6e39defdad2d5fd8ead6a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/721700
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Mark Freeman <markfreeman@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Commit-Queue: Junyang Shao <shaojunyang@google.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
This commit is contained in:
Filippo Valsorda 2025-11-18 17:19:04 +01:00 committed by Gopher Robot
parent ca37d24e0b
commit c1b7112af8
3 changed files with 37 additions and 11 deletions

View file

@ -0,0 +1,2 @@
[NotifyContext] now cancels the returned context with [context.CancelCauseFunc]
and an error indicating which signal was received.

View file

@ -272,11 +272,14 @@ func process(sig os.Signal) {
// the returned context. Future interrupts received will not trigger the default
// (exit) behavior until the returned stop function is called.
//
// If a signal causes the returned context to be canceled, calling
// [context.Cause] on it will return an error describing the signal.
//
// The stop function releases resources associated with it, so code should
// call stop as soon as the operations running in this Context complete and
// signals no longer need to be diverted to the context.
func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Context, stop context.CancelFunc) {
ctx, cancel := context.WithCancel(parent)
ctx, cancel := context.WithCancelCause(parent)
c := &signalCtx{
Context: ctx,
cancel: cancel,
@ -287,8 +290,8 @@ func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Co
if ctx.Err() == nil {
go func() {
select {
case <-c.ch:
c.cancel()
case s := <-c.ch:
c.cancel(signalError(s.String() + " signal received"))
case <-c.Done():
}
}()
@ -299,13 +302,13 @@ func NotifyContext(parent context.Context, signals ...os.Signal) (ctx context.Co
type signalCtx struct {
context.Context
cancel context.CancelFunc
cancel context.CancelCauseFunc
signals []os.Signal
ch chan os.Signal
}
func (c *signalCtx) stop() {
c.cancel()
c.cancel(nil)
Stop(c.ch)
}
@ -333,3 +336,9 @@ func (c *signalCtx) String() string {
buf = append(buf, ')')
return string(buf)
}
type signalError string
func (s signalError) Error() string {
return string(s)
}

View file

@ -9,6 +9,7 @@ package signal
import (
"bytes"
"context"
"errors"
"flag"
"fmt"
"internal/testenv"
@ -723,6 +724,9 @@ func TestNotifyContextNotifications(t *testing.T) {
}
wg.Wait()
<-ctx.Done()
if got, want := context.Cause(ctx).Error(), "interrupt signal received"; got != want {
t.Errorf("context.Cause(ctx) = %q, want %q", got, want)
}
fmt.Println("received SIGINT")
// Sleep to give time to simultaneous signals to reach the process.
// These signals must be ignored given stop() is not called on this code.
@ -797,11 +801,15 @@ func TestNotifyContextStop(t *testing.T) {
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
if got := context.Cause(c); got != context.Canceled {
t.Errorf("context.Cause(c.Err()) = %q, want %q", got, context.Canceled)
}
}
func TestNotifyContextCancelParent(t *testing.T) {
parent, cancelParent := context.WithCancel(context.Background())
defer cancelParent()
parent, cancelParent := context.WithCancelCause(context.Background())
parentCause := errors.New("parent canceled")
defer cancelParent(parentCause)
c, stop := NotifyContext(parent, syscall.SIGINT)
defer stop()
@ -809,18 +817,22 @@ func TestNotifyContextCancelParent(t *testing.T) {
t.Errorf("c.String() = %q, want %q", got, want)
}
cancelParent()
cancelParent(parentCause)
<-c.Done()
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
if got := context.Cause(c); got != parentCause {
t.Errorf("context.Cause(c) = %q, want %q", got, parentCause)
}
}
func TestNotifyContextPrematureCancelParent(t *testing.T) {
parent, cancelParent := context.WithCancel(context.Background())
defer cancelParent()
parent, cancelParent := context.WithCancelCause(context.Background())
parentCause := errors.New("parent canceled")
defer cancelParent(parentCause)
cancelParent() // Prematurely cancel context before calling NotifyContext.
cancelParent(parentCause) // Prematurely cancel context before calling NotifyContext.
c, stop := NotifyContext(parent, syscall.SIGINT)
defer stop()
@ -832,6 +844,9 @@ func TestNotifyContextPrematureCancelParent(t *testing.T) {
if got := c.Err(); got != context.Canceled {
t.Errorf("c.Err() = %q, want %q", got, context.Canceled)
}
if got := context.Cause(c); got != parentCause {
t.Errorf("context.Cause(c) = %q, want %q", got, parentCause)
}
}
func TestNotifyContextSimultaneousStop(t *testing.T) {