internal/nettest: add internal fake networking implementation

This is an implementation of proposal #77362.
Add it as an internal package so we can use it for fake networking
support in net/http/httptest.

This can also eventually replace the fake network implementations
in net/http and net/http/internal/http2.

For #77362
For #76608

Change-Id: Ida069e139ede90bf01e413f2a69cf0116a6a6964
Reviewed-on: https://go-review.googlesource.com/c/go/+/769520
LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Nicholas Husin <husin@google.com>
Reviewed-by: Nicholas Husin <nsh@golang.org>
Auto-Submit: Damien Neil <dneil@google.com>
This commit is contained in:
Damien Neil 2026-04-18 11:08:35 -04:00 committed by Gopher Robot
parent 2e67b18935
commit a871fd3732
9 changed files with 1996 additions and 0 deletions

View file

@ -685,6 +685,9 @@ var depsRules = `
net/http, net/http/internal/ascii
< net/http/cookiejar, net/http/httputil;
NET, internal/gate
< internal/nettest;
net/http, flag
< net/http/httptest;

View file

@ -0,0 +1,437 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
import (
"bytes"
"errors"
"io"
"math"
"net"
"net/netip"
"os"
"time"
)
// Conn is an in-memory test implementation of net.Conn.
type Conn struct {
// Conns come in pairs.
// Writes to one Conn are read by its peer, and vice-versa.
//
// A connHalf handles one direction of data flow.
// A Conn consists of read and write halves.
// A Conn's peer has the same halves, only swapped.
//
// A Conn reads from r and writes to w.
r, w *connHalf
// peer is the other endpoint.
peer *Conn
}
// NewConnPair returns a pair of connected Conns.
func NewConnPair() (*Conn, *Conn) {
return newConnPair(
net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10000")),
net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10001")),
)
}
func newConnPair(addr1, addr2 net.Addr) (*Conn, *Conn) {
h1 := newConnHalf(addr1)
h2 := newConnHalf(addr2)
c1 := &Conn{r: h1, w: h2}
c2 := &Conn{r: h2, w: h1}
c1.peer = c2
c2.peer = c1
c1.SetReadBufferSize(-1)
c2.SetReadBufferSize(-1)
return c1, c2
}
// Peer returns the other end of the connection.
func (c *Conn) Peer() *Conn {
return c.peer
}
// Read reads data from the connection.
func (c *Conn) Read(b []byte) (n int, err error) {
n, err = c.r.read(b)
if err != nil && err != io.EOF {
err = &net.OpError{
Op: "read",
Net: "tcp",
Source: c.RemoteAddr(),
Addr: c.LocalAddr(),
Err: err,
}
}
return n, err
}
// CanRead reports whether Read can proceed without blocking.
func (c *Conn) CanRead() bool {
return c.r.canRead()
}
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (n int, err error) {
n, err = c.w.write(b)
if err != nil {
err = &net.OpError{
Op: "write",
Net: "tcp",
Source: c.LocalAddr(),
Addr: c.RemoteAddr(),
Err: err,
}
}
return n, err
}
// IsClosed reports whether the connection has been closed.
// A connection is closed if [CloseRead] and [CloseWrite] are both called,
// or if [Close] is called.
//
// To identify when the other side of the Conn has been closed,
// use Conn.Peer().IsClosed().
func (c *Conn) IsClosed() bool {
c.r.lock()
readClosed := c.r.readClosed
c.r.unlock()
c.w.lock()
writeClosed := c.w.writeClosed
c.w.unlock()
return readClosed && writeClosed
}
var errClosedByPeer = errors.New("connection closed by peer")
// CloseRead shuts down the reading side of the connection.
func (c *Conn) CloseRead() error {
c.r.lock()
defer c.r.unlock()
c.r.buf.Reset() // discard unread data
c.r.readClosed = true
return nil
}
// CloseWrite shuts down the writing side of the connection.
func (c *Conn) CloseWrite() error {
c.w.lock()
defer c.w.unlock()
c.w.writeClosed = true
return nil
}
// Close closes the connection.
func (c *Conn) Close() error {
c.r.lock()
readClosed := c.r.readClosed
c.r.buf.Reset() // discard unread data
c.r.readClosed = true
err := c.r.closeErr
c.r.unlock()
c.w.lock()
writeClosed := c.w.writeClosed
c.w.writeClosed = true
c.w.unlock()
if readClosed && writeClosed {
err = net.ErrClosed
}
if err != nil {
err = &net.OpError{
Op: "close",
Net: "tcp",
Addr: c.LocalAddr(),
Err: err,
}
}
return err
}
// SetCloseError sets the error returned by Close.
// Close still closes the connection.
// A nil error restores the usual behavior.
func (c *Conn) SetCloseError(err error) {
c.r.lock()
c.r.closeErr = err
c.r.unlock()
}
// LocalAddr returns the (fake) local network address.
func (c *Conn) LocalAddr() net.Addr {
c.r.lock()
defer c.r.unlock()
return c.r.addr
}
// SetLocalAddr sets the local address.
//
// To set the remote address, set the local address of Conn's peer.
func (c *Conn) SetLocalAddr(addr net.Addr) {
c.r.lock()
defer c.r.unlock()
c.r.addr = addr
}
// LocalAddr returns the (fake) remote network address.
func (c *Conn) RemoteAddr() net.Addr {
c.r.lock()
defer c.r.unlock()
return c.w.addr
}
// SetDeadline sets the read and write deadlines for the connection.
func (c *Conn) SetDeadline(t time.Time) error {
c.SetReadDeadline(t)
c.SetWriteDeadline(t)
return nil
}
// SetReadDeadline sets the read deadline for the connection.
func (c *Conn) SetReadDeadline(t time.Time) error {
c.r.readDeadline.setDeadline(c.r, t)
return nil
}
// SetWriteDeadline sets the write deadline for the connection.
func (c *Conn) SetWriteDeadline(t time.Time) error {
c.w.writeDeadline.setDeadline(c.w, t)
return nil
}
// SetReadBufferSize sets the connection's read buffer.
// Writes to the other end of the connection will block so long as the buffer is full.
// Setting the size to 0 blocks all writes until the size is increased.
func (c *Conn) SetReadBufferSize(size int) {
if size < 0 {
size = math.MaxInt
}
c.r.setBufferSize(size)
}
// SetReadError causes any currently blocked and future Read calls to return
// a net.OpError wrapping err. It does not affect the other half of the connection.
// Reads will return any buffered data before returning the error,
// including data written after the error is set and io.EOF after the other end is closed.
// A nil error restores the usual behavior.
func (c *Conn) SetReadError(err error) {
c.r.lock()
defer c.r.unlock()
c.r.readErr = err
}
// SetWriteError causes any currently blocked and future Write calls to return
// a net.OpError wrapping err. It does not affect the other half of the connection.
// Writes will not write data to the connection buffer while an error is set.
// A nil error restores the usual behavior.
func (c *Conn) SetWriteError(err error) {
c.w.lock()
defer c.w.unlock()
c.w.writeErr = err
}
// connHalf is one direction data flow in a Conn.
// The connHalf contains a buffer.
// Writes to the connHalf push to the buffer and reads pull from it.
type connHalf struct {
addr net.Addr
// A half can be readable and/or writable.
//
// These four channels act as a lock,
// and allow waiting for readability/writability.
// When the half is unlocked, exactly one channel contains a value.
// When the half is locked, all channels are empty.
lockr chan struct{} // readable
lockw chan struct{} // writable
lockrw chan struct{} // readable and writable
lockc chan struct{} // neither readable nor writable
// Read and write timeouts.
readDeadline, writeDeadline connDeadline
bufMax int // maximum buffer size
buf bytes.Buffer
readClosed, writeClosed bool
readErr, writeErr error // errors returned by reads/writes
closeErr error // error returned by closing the conn reading from this half
}
func newConnHalf(addr net.Addr) *connHalf {
h := &connHalf{
addr: addr,
lockw: make(chan struct{}, 1),
lockr: make(chan struct{}, 1),
lockrw: make(chan struct{}, 1),
lockc: make(chan struct{}, 1),
bufMax: math.MaxInt, // unlimited
}
h.unlock()
return h
}
// lock locks h.
func (h *connHalf) lock() {
select {
case <-h.lockw: // writable
case <-h.lockr: // readable
case <-h.lockrw: // readable and writable
case <-h.lockc: // neither readable nor writable
}
}
// unlock unlocks h.
func (h *connHalf) unlock() {
canRead := h.canReadLocked()
canWrite := h.canWriteLocked()
switch {
case canRead && canWrite:
h.lockrw <- struct{}{} // readable and writable
case canRead:
h.lockr <- struct{}{} // readable
case canWrite:
h.lockw <- struct{}{} // writable
default:
h.lockc <- struct{}{} // neither readable nor writable
}
}
func (h *connHalf) canRead() bool {
h.lock()
defer h.unlock()
return h.canReadLocked()
}
func (h *connHalf) canReadLocked() bool {
return h.readErr != nil || h.readDeadline.expired || h.buf.Len() > 0 || h.readClosed || h.writeClosed
}
func (h *connHalf) canWriteLocked() bool {
return h.writeErr != nil || h.writeDeadline.expired || h.bufMax > h.buf.Len() || h.readClosed || h.writeClosed
}
// waitAndLockForRead waits until h is readable and locks it.
func (h *connHalf) waitAndLockForRead() {
select {
case <-h.lockr:
// readable
case <-h.lockrw:
// readable and writable
}
}
// waitAndLockForWrite waits until h is writable and locks it.
func (h *connHalf) waitAndLockForWrite() {
select {
case <-h.lockw:
// writable
case <-h.lockrw:
// readable and writable
}
}
func (h *connHalf) read(b []byte) (n int, err error) {
h.waitAndLockForRead()
defer h.unlock()
if h.readClosed {
return 0, net.ErrClosed
}
if h.readDeadline.expired {
return 0, os.ErrDeadlineExceeded
}
if h.buf.Len() > 0 {
return h.buf.Read(b)
}
if h.writeClosed {
return 0, io.EOF
}
return 0, h.readErr
}
func (h *connHalf) setBufferSize(size int) {
h.lock()
defer h.unlock()
h.bufMax = size
}
func (h *connHalf) write(b []byte) (n int, err error) {
for n < len(b) {
nn, err := h.writePartial(b[n:])
n += nn
if err != nil {
return n, err
}
}
return n, nil
}
func (h *connHalf) writePartial(b []byte) (n int, err error) {
h.waitAndLockForWrite()
defer h.unlock()
if h.writeClosed {
return 0, net.ErrClosed
}
if h.writeDeadline.expired {
return 0, os.ErrDeadlineExceeded
}
if h.readClosed {
return 0, errClosedByPeer
}
if h.writeErr != nil {
return 0, h.writeErr
}
writeMax := h.bufMax - h.buf.Len()
if writeMax < len(b) {
b = b[:writeMax]
}
return h.buf.Write(b)
}
type connDeadline struct {
timer *time.Timer
expired bool
}
type locker interface {
lock()
unlock()
}
func (d *connDeadline) setDeadline(mu locker, t time.Time) {
mu.lock()
defer mu.unlock()
if d.timer != nil {
d.timer.Stop()
d.timer = nil
}
if t.IsZero() {
// No deadline.
d.expired = false
return
}
expiry := time.Until(t)
if expiry <= 0 {
// Deadline has already passed.
d.expired = true
return
}
// Deadline is in the future.
d.expired = false
var timer *time.Timer
timer = time.AfterFunc(expiry, func() {
mu.lock()
defer mu.unlock()
if d.timer == timer {
d.timer = nil
d.expired = true
}
})
d.timer = timer
}

View file

@ -0,0 +1,450 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest_test
import (
"bytes"
"errors"
"internal/nettest"
"io"
"net"
"os"
"testing"
"testing/synctest"
"time"
)
func TestConnReadWrite(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
cliConn, srvConn := nettest.NewConnPair()
cliData := []byte("hello")
srvData := []byte("HELLO")
if n, err := cliConn.Write(cliData); n != len(cliData) || err != nil {
t.Fatalf("cliConn.Write(%q) = %v, %v; want %v, nil", cliData, n, err, len(cliData))
}
if err := cliConn.CloseWrite(); err != nil {
t.Fatalf("cliConn.CloseWrite() = %v, want nil", err)
}
if n, err := srvConn.Write(srvData); n != len(srvData) || err != nil {
t.Fatalf("srvConn.Write(%q) = %v, %v; want %v, nil", srvData, n, err, len(srvData))
}
if err := srvConn.CloseWrite(); err != nil {
t.Fatalf("cliConn.CloseWrite() = %v, want nil", err)
}
gotCli, err := io.ReadAll(cliConn)
if !bytes.Equal(gotCli, srvData) || err != nil {
t.Fatalf("io.ReadAll(cliConn) = %q, %v; want %v, nil", gotCli, err, srvData)
}
gotSrv, err := io.ReadAll(srvConn)
if !bytes.Equal(gotSrv, cliData) || err != nil {
t.Fatalf("io.ReadAll(srvConn) = %q, %v; want %v, nil", gotSrv, err, cliData)
}
})
}
func TestConnZeroBuffer(t *testing.T) {
// Exercise the case where one side of the conn is blocked writing and the
// other side is blocked reading.
// This can only happen when the read buffer has been set to 0, blocking all writes.
synctest.Test(t, func(t *testing.T) {
rconn, wconn := nettest.NewConnPair()
rconn.SetReadBufferSize(0)
var readDone, writeDone bool
go func() {
rconn.Read(make([]byte, 100))
readDone = true
}()
go func() {
wconn.Write([]byte("a"))
writeDone = true
}()
synctest.Wait()
if readDone || writeDone {
t.Errorf("before unblocking: readDone=%v, writeDone=%v; want false", readDone, writeDone)
}
wconn.Close()
synctest.Wait()
if !readDone || !writeDone {
t.Errorf("after unblocking: readDone=%v, writeDone=%v; want true", readDone, writeDone)
}
})
}
func TestConnPartialWrite(t *testing.T) {
// A blocking write to a conn successfully writes some, but not all data.
synctest.Test(t, func(t *testing.T) {
const readSize = 5
data := []byte("0123456789")
rconn, wconn := nettest.NewConnPair()
rconn.SetReadBufferSize(1)
go func() {
got := make([]byte, readSize)
if n, err := io.ReadFull(rconn, got); n != readSize || err != nil {
t.Errorf("io.ReadFull() = %v, %v; want %v, nil", n, err, readSize)
}
if want := data[:readSize]; !bytes.Equal(got, want) {
t.Errorf("read %q, want %q", got, want)
}
rconn.Close()
}()
n, err := wconn.Write(data)
if n != readSize+1 || err == nil {
t.Errorf("Write() = %v, %v; want %v, error", n, err, readSize+1)
}
})
}
func TestConnReadDeadline(t *testing.T) {
for _, unblock := range []struct {
name string
f func(*nettest.Conn)
}{{
name: "Write",
f: func(c *nettest.Conn) {
c.Write([]byte("x"))
},
}, {
name: "Close",
f: func(c *nettest.Conn) {
c.Close()
},
}, {
name: "CloseWrite",
f: func(c *nettest.Conn) {
c.CloseWrite()
},
}} {
for _, setDeadline := range []struct {
name string
f func(*nettest.Conn, time.Time) error
}{{
name: "SetDeadline",
f: (*nettest.Conn).SetDeadline,
}, {
name: "SetReadDeadline",
f: (*nettest.Conn).SetReadDeadline,
}} {
t.Run(unblock.name+"/"+setDeadline.name, func(t *testing.T) {
testDeadline(t, func() deadlineTest {
rconn, wconn := nettest.NewConnPair()
return deadlineTest{
what: "Read()",
block: func() error {
_, err := rconn.Read(make([]byte, 1))
return err
},
unblock: func() {
unblock.f(wconn)
},
setDeadline: func(d time.Duration) {
setDeadline.f(rconn, time.Now().Add(d))
},
}
})
})
}
}
}
func TestConnWriteDeadline(t *testing.T) {
for _, unblock := range []struct {
name string
f func(*nettest.Conn)
}{{
name: "Read",
f: func(c *nettest.Conn) {
io.Copy(io.Discard, c)
},
}, {
name: "Close",
f: func(c *nettest.Conn) {
c.Close()
},
}, {
name: "CloseRead",
f: func(c *nettest.Conn) {
c.CloseRead()
},
}} {
for _, setDeadline := range []struct {
name string
f func(*nettest.Conn, time.Time) error
}{{
name: "SetDeadline",
f: (*nettest.Conn).SetDeadline,
}, {
name: "SetWriteDeadline",
f: (*nettest.Conn).SetWriteDeadline,
}} {
t.Run(unblock.name+"/"+setDeadline.name, func(t *testing.T) {
testDeadline(t, func() deadlineTest {
rconn, wconn := nettest.NewConnPair()
rconn.SetReadBufferSize(1)
return deadlineTest{
what: "Write()",
block: func() error {
_, err := wconn.Write([]byte("1234"))
wconn.Close()
return err
},
unblock: func() {
go unblock.f(rconn)
},
setDeadline: func(d time.Duration) {
setDeadline.f(wconn, time.Now().Add(d))
},
}
})
})
}
}
}
func TestConnCanRead(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
rconn, wconn := nettest.NewConnPair()
if got, want := rconn.CanRead(), false; got != want {
t.Fatalf("before writing data: rconn.CanRead() = %v, want %v", got, want)
}
wconn.Write([]byte("a"))
if got, want := rconn.CanRead(), true; got != want {
t.Fatalf("after writing data: rconn.CanRead() = %v, want %v", got, want)
}
rconn.Read(make([]byte, 1))
if got, want := rconn.CanRead(), false; got != want {
t.Fatalf("after reading data: rconn.CanRead() = %v, want %v", got, want)
}
wconn.Close()
if got, want := rconn.CanRead(), true; got != want {
t.Fatalf("after closing: rconn.CanRead() = %v, want %v", got, want)
}
})
}
func TestConnIsClosed(t *testing.T) {
for _, test := range []struct {
name string
f func() *nettest.Conn
want bool
}{{
name: "unclosed",
f: func() *nettest.Conn {
conn, _ := nettest.NewConnPair()
return conn
},
want: false,
}, {
name: "closed",
f: func() *nettest.Conn {
conn, _ := nettest.NewConnPair()
conn.Close()
return conn
},
want: true,
}, {
name: "read-closed",
f: func() *nettest.Conn {
conn, _ := nettest.NewConnPair()
conn.CloseRead()
return conn
},
want: false,
}, {
name: "write-closed",
f: func() *nettest.Conn {
conn, _ := nettest.NewConnPair()
conn.CloseWrite()
return conn
},
want: false,
}, {
name: "read-write-closed",
f: func() *nettest.Conn {
conn, _ := nettest.NewConnPair()
conn.CloseRead()
conn.CloseWrite()
return conn
},
want: true,
}} {
synctestSubtest(t, test.name, func(t *testing.T) {
conn := test.f()
if got, want := conn.IsClosed(), test.want; got != want {
t.Fatalf("conn.IsClosed() = %v, want %v", got, want)
}
if got, want := conn.Peer().IsClosed(), false; got != want {
t.Fatalf("conn.Peer().IsClosed() = %v, want %v", got, want)
}
})
}
}
var anyError = errors.New("any") // anyError is passed to isOpError to match any error
func isOpError(err, want error) bool {
oe, ok := err.(*net.OpError)
return ok && (oe.Err == want || want == anyError)
}
func wantConnReadBytes(t *testing.T, c *nettest.Conn, want []byte) {
t.Helper()
got := make([]byte, len(want))
n, err := io.ReadFull(c, got)
if n < len(want) || err != nil {
t.Fatalf("io.ReadFull = %v, %v; want %v, nil", n, err, len(want))
}
if !bytes.Equal(got, want) {
t.Fatalf("io.ReadFull read %q, want %q", got, want)
}
}
func wantConnReadErr(t *testing.T, c *nettest.Conn, want error) {
t.Helper()
n, err := c.Read(make([]byte, 1))
if want == io.EOF {
if n != 0 || err != io.EOF {
t.Fatalf("c.Read() = %v, %v; want 0, io.EOF", n, err)
}
} else {
if n != 0 || !isOpError(err, want) {
t.Fatalf("c.Read() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
}
}
}
func wantConnReadBlocked(t *testing.T, c *nettest.Conn) {
done := false
go func() {
n, err := c.Read(make([]byte, 1))
if n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("c.Read() = %v, %v; want 0, ErrDeadlineExceeded", n, err)
}
done = true
}()
synctest.Wait()
if done {
t.Fatalf("Read unexpectedly returned before setting deadline")
}
c.SetReadDeadline(time.Now().Add(-1 * time.Second))
synctest.Wait()
c.SetReadDeadline(time.Time{})
if !done {
t.Fatalf("Read unexpectedly did not return after setting deadline")
}
}
func TestConnSetReadError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
wantErr := errors.New("error")
rconn, wconn := nettest.NewConnPair()
rconn.SetReadError(wantErr)
// Consume buffer before returning error.
wconn.Write([]byte("one"))
wantConnReadBytes(t, rconn, []byte("one"))
wantConnReadErr(t, rconn, wantErr)
// Write more to the buffer, suppressing error until buffer drains again.
wconn.Write([]byte("two"))
wantConnReadBytes(t, rconn, []byte("two"))
wantConnReadErr(t, rconn, wantErr)
// Error may be cleared.
rconn.SetReadError(nil)
wantConnReadBlocked(t, rconn)
// Close overrides read error.
rconn.SetReadError(wantErr)
wconn.Write([]byte("three"))
wconn.Close()
wantConnReadBytes(t, rconn, []byte("three"))
wantConnReadErr(t, rconn, io.EOF)
// Setting another read error does not override Close.
rconn.SetReadError(nil)
wantConnReadErr(t, rconn, io.EOF)
rconn.SetReadError(wantErr)
wantConnReadErr(t, rconn, io.EOF)
// ErrClosed takes precedence over read error.
rconn.Close()
wantConnReadErr(t, rconn, net.ErrClosed)
})
}
func wantConnWriteBytes(t *testing.T, c *nettest.Conn, b []byte) {
t.Helper()
if n, err := c.Write(b); n != len(b) || err != nil {
t.Fatalf("c.Write() = %v, %v; want %v, nil", n, err, len(b))
}
}
func wantConnWriteErr(t *testing.T, c *nettest.Conn, want error) {
t.Helper()
n, err := c.Write(make([]byte, 1))
if n != 0 || !isOpError(err, want) {
t.Fatalf("c.Write() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
}
}
func TestConnSetWriteError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
wantErr := errors.New("error")
rconn, wconn := nettest.NewConnPair()
wconn.SetWriteError(wantErr)
// Error blocks writes.
wantConnWriteErr(t, wconn, wantErr)
wantConnReadBlocked(t, rconn)
// Error may be cleared.
wconn.SetWriteError(nil)
wantConnWriteBytes(t, wconn, []byte("one"))
// Restoring error does not prevent reading buffered data.
wconn.SetWriteError(wantErr)
wantConnWriteErr(t, wconn, wantErr)
wantConnReadBytes(t, rconn, []byte("one"))
// Error does not interfere with closing the conn.
wconn.Close()
wantConnReadErr(t, rconn, io.EOF)
})
}
func TestConnSetCloseError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
wantErr := errors.New("error")
rconn, wconn := nettest.NewConnPair()
wconn.SetCloseError(wantErr)
if _, err := wconn.Write([]byte("one")); err != nil {
t.Fatalf("wconn.Write = %v, want success", err)
}
if err := wconn.Close(); !isOpError(err, wantErr) {
t.Fatalf("wconn.Close = %v, want OpError{Err: %v}", err, wantErr)
}
if err := wconn.Close(); !isOpError(err, net.ErrClosed) {
t.Fatalf("wconn.Close = %v, want OpError{Err: net.ErrClosed}", err)
}
wantConnReadBytes(t, rconn, []byte("one"))
wantConnReadErr(t, rconn, io.EOF)
})
}
func TestConnCloseReadWriteError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
conn, _ := nettest.NewConnPair()
conn.SetCloseError(errors.New("error"))
if err := conn.CloseRead(); err != nil {
t.Fatalf("conn.CloseRead = %v, want nil", err)
}
if err := conn.CloseWrite(); err != nil {
t.Fatalf("conn.CloseRead = %v, want nil", err)
}
})
}

View file

@ -0,0 +1,148 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
import (
"context"
"internal/gate"
"net"
"net/netip"
)
// Listener is an in-memory test implementation of net.Listener.
type Listener struct {
gate gate.Gate
queue queue[*Conn]
closed bool
acceptErr error
closeErr error
addr net.Addr
nextaddr netip.AddrPort
}
// NewListener returns a new Listener.
func NewListener() *Listener {
return &Listener{
gate: gate.New(false),
addr: net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1000")),
nextaddr: netip.MustParseAddrPort("127.0.0.1:10001"),
}
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (li *Listener) Close() error {
li.lock()
defer li.unlock()
err := li.closeErr
li.closed = true
li.acceptErr = net.ErrClosed
li.closeErr = net.ErrClosed
if err != nil {
return &net.OpError{
Op: "close",
Net: "tcp",
Addr: li.addr,
Err: err,
}
}
return err
}
// Addr returns the listener's network address.
//
// The address is always a *net.TCPAddr.
func (li *Listener) Addr() net.Addr {
li.lock()
defer li.unlock()
return li.addr
}
// SetAddr sets the listener's network address.
func (li *Listener) SetAddr(addr net.Addr) {
li.lock()
defer li.unlock()
li.addr = addr
}
// NewConn returns a new connection to the listener.
//
// Accept will return the other side of the conn.
func (li *Listener) NewConn() *Conn {
return li.NewConnConfig(func(*Conn) {})
}
// NewConnConfig returns a new connection to the listener.
//
// The function f is called with the new client connection.
// After f returns, Accept will return the other side of the connection.
//
// For example, to create a connection from a specific IP address:
//
// conn := li.NewConnConfig(func(conn *nettest.Conn) {
// conn.SetLocalAddr(net.TCPAddrFromAddrPort(netip.MustParseAddrPort("10.0.0.1:1234")))
// })
func (li *Listener) NewConnConfig(f func(*Conn)) *Conn {
li.lock()
defer li.unlock()
cli, srv := newConnPair(
net.TCPAddrFromAddrPort(li.nextaddr),
li.addr,
)
li.nextaddr = netip.AddrPortFrom(li.nextaddr.Addr(), li.nextaddr.Port()+1)
f(cli)
li.queue.push(srv)
return cli
}
// Accept waits for and returns the next connection to the listener.
//
// The connections returned by Accept are always [*Conn]s.
func (li *Listener) Accept() (net.Conn, error) {
li.gate.WaitAndLock(context.Background())
defer li.unlock()
if li.acceptErr != nil && li.queue.len() == 0 {
return nil, &net.OpError{
Op: "accept",
Net: "tcp",
Addr: li.addr,
Err: li.acceptErr,
}
}
return li.queue.pop(), nil
}
// SetAcceptError causes any currently blocked and future Accept calls to return
// a net.OpError wrapping err.
// Accept will return any available connections before returning the error,
// including connections created after the error is set.
// A nil error restores the usual behavior.
func (li *Listener) SetAcceptError(err error) {
li.gate.Lock()
defer li.unlock()
if !li.closed {
li.acceptErr = err
}
}
// SetCloseError sets the error returned by Close.
// Close still closes the listener.
// A nil error restores the usual behavior.
func (li *Listener) SetCloseError(err error) {
li.gate.Lock()
defer li.unlock()
if !li.closed {
li.closeErr = err
}
}
func (li *Listener) lock() {
li.gate.Lock()
}
func (li *Listener) unlock() {
li.gate.Unlock(li.acceptErr != nil || li.queue.len() > 0)
}

View file

@ -0,0 +1,212 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest_test
import (
"errors"
"internal/nettest"
"io"
"net"
"net/netip"
"slices"
"testing"
"testing/synctest"
)
func TestListenerNewConn(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
li := nettest.NewListener()
defer li.Close()
// Create several connections in parallel.
want := []string{"a", "b", "c"}
for i := range len(want) {
go func() {
conn := li.NewConn()
defer conn.Close()
n, err := conn.Write([]byte(want[i]))
if n != len(want[i]) || err != nil {
t.Errorf("conn%v.Write() = %v, %v; want %v, nil", i, n, err, len(want[i]))
}
}()
}
// Accept the connections in parallel as well.
got := make([]string, len(want))
for i := range len(want) {
go func() {
conn, err := li.Accept()
if err != nil {
t.Errorf("li.Accept() = %v", err)
}
b, err := io.ReadAll(conn)
if err != nil {
t.Errorf("io.ReadAll(conn%v) = %v", i, err)
}
got[i] = string(b)
}()
}
synctest.Wait()
slices.Sort(got)
slices.Sort(want)
if !slices.Equal(got, want) {
t.Errorf("connections read %v; want %q", got, want)
}
})
}
func TestListenerInterruptAccept(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
li := nettest.NewListener()
var acceptErr error
go func() {
_, acceptErr = li.Accept()
}()
synctest.Wait()
if acceptErr != nil {
t.Fatalf("li.Accept() = %v, want still running before close", acceptErr)
}
li.Close()
synctest.Wait()
if !errors.Is(acceptErr, net.ErrClosed) {
t.Fatalf("li.Accept() = %v, want ErrClosed", acceptErr)
}
})
}
func TestListenerAddresses(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
srvaddr := netip.MustParseAddrPort("10.0.0.1:80")
cliaddr := netip.MustParseAddrPort("10.0.0.2:1234")
li := nettest.NewListener()
defer li.Close()
li.SetAddr(net.TCPAddrFromAddrPort(srvaddr))
if got, want := li.Addr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
t.Errorf("li.Addr() = %v, want %v", got, want)
}
cli := li.NewConnConfig(func(conn *nettest.Conn) {
conn.SetLocalAddr(net.TCPAddrFromAddrPort(cliaddr))
})
srvc, err := li.Accept()
if err != nil {
t.Fatalf("li.Accept() = %v", err)
}
srv := srvc.(*nettest.Conn)
if cli.Peer() != srv {
t.Errorf("cli.Peer() != srv; should be the same")
}
if srv.Peer() != cli {
t.Errorf("cli.Peer() != srv; should be the same")
}
if got, want := cli.LocalAddr().(*net.TCPAddr).AddrPort(), cliaddr; got != want {
t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
}
if got, want := cli.RemoteAddr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
}
if got, want := srv.LocalAddr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
t.Errorf("srv.LocalAddr() = %v, want %v", got, want)
}
if got, want := srv.RemoteAddr().(*net.TCPAddr).AddrPort(), cliaddr; got != want {
t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
}
})
}
func wantListenerAccept(t *testing.T, li *nettest.Listener, want *nettest.Conn) {
t.Helper()
got, err := li.Accept()
if err != nil {
t.Fatalf("li.Accept() = %v, want conn", err)
}
if got != want {
t.Fatalf("li.Accept() returned unexpected conn")
}
}
func wantListenerAcceptErr(t *testing.T, li *nettest.Listener, want error) {
t.Helper()
got, err := li.Accept()
if got != nil || !isOpError(err, want) {
t.Fatalf("li.Accept() = %p, %v; want nil, OpError{Err: %q}", got, err, want)
}
}
func wantListenerAcceptBlocked(t *testing.T, li *nettest.Listener) {
cancelErr := errors.New("cancel")
done := false
go func() {
got, err := li.Accept()
if got != nil || !errors.Is(err, cancelErr) {
t.Errorf("li.Accept = %p, %v; want nil, cancelErr", got, err)
}
done = true
}()
synctest.Wait()
if done {
t.Fatalf("Accept unexpectedly returned before canceling")
}
li.SetAcceptError(cancelErr)
synctest.Wait()
li.SetAcceptError(nil)
if !done {
t.Fatalf("Accept unexpectedly did not return after canceling")
}
}
func TestListenerSetAcceptError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
acceptErr := errors.New("accept error")
li := nettest.NewListener()
defer li.Close()
li.SetAcceptError(acceptErr)
// Accept conns from queue before returning error.
c1 := li.NewConn()
wantListenerAccept(t, li, c1.Peer())
wantListenerAcceptErr(t, li, acceptErr)
// Add a new conn, suppressing error until the queue is empty.
c2 := li.NewConn()
wantListenerAccept(t, li, c2.Peer())
wantListenerAcceptErr(t, li, acceptErr)
// Error may be cleared.
li.SetAcceptError(nil)
wantListenerAcceptBlocked(t, li)
// ErrClosed takes precedence over accept error.
li.SetAcceptError(acceptErr)
li.Close()
wantListenerAcceptErr(t, li, net.ErrClosed)
})
}
func TestListenerSetCloseError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
li := nettest.NewListener()
closeErr := errors.New("close error")
li.SetCloseError(closeErr)
// First close uses the user-provided error.
if err := li.Close(); !isOpError(err, closeErr) {
t.Fatalf("li.Close() = %v; want OpError wrapping accept error", err)
}
// Repeated closes return ErrClosed.
if err := li.Close(); !isOpError(err, net.ErrClosed) {
t.Fatalf("li.Close() = %v; want OpError wrapping net.ErrClosed", err)
}
})
}

View file

@ -0,0 +1,134 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest_test
import (
"errors"
"internal/nettest"
"net"
"os"
"sync/atomic"
"testing"
"testing/synctest"
"time"
)
var (
_ net.Conn = (*nettest.Conn)(nil)
_ net.Listener = (*nettest.Listener)(nil)
_ net.PacketConn = (*nettest.PacketConn)(nil)
)
func synctestSubtest(t *testing.T, name string, f func(t *testing.T)) {
t.Run(name, func(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
f(t)
})
})
}
// A deadlineTest describes an operation which blocks until a deadline,
// and a separate operation which unblocks it.
type deadlineTest struct {
what string // name of the blocking op; e.g., "Read"
block func() error // blocking op; e.g., reading from a conn
unblock func() // unblocking op; e.g. writing to the other side
setDeadline func(d time.Duration) // deadline func; e.g., SetReadDeadline
}
// testDeadline tests a variety of scenarios involving deadlines.
func testDeadline(t *testing.T, setup func() deadlineTest) {
synctestSubtest(t, "no deadline", func(t *testing.T) {
test := setup()
test.unblock()
synctest.Wait()
if err := test.block(); errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("%v: %v, want not deadline exceeded", test.what, err)
}
})
synctestSubtest(t, "unblock before setdeadline", func(t *testing.T) {
test := setup()
test.unblock()
synctest.Wait()
test.setDeadline(5 * time.Second)
if err := test.block(); errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("%v: %v, want not deadline exceeded", test.what, err)
}
})
synctestSubtest(t, "unblock after blocking", func(t *testing.T) {
test := setup()
test.setDeadline(5 * time.Second)
var done bool
go func() {
if err := test.block(); errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("%v: %v, want not deadline exceeded", test.what, err)
}
done = true
}()
synctest.Wait()
if done {
t.Fatalf("%v: unexpectedly returned before unblocking", test.what)
}
test.unblock()
synctest.Wait()
if !done {
t.Fatalf("%v: did not return after unblocking", test.what)
}
})
synctestSubtest(t, "deadline expires", func(t *testing.T) {
test := setup()
start := time.Now()
const delay = 5 * time.Second
test.setDeadline(delay)
var done atomic.Bool
go func() {
if err := test.block(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("%v: %v, want os.ErrDeadlineExceeded", test.what, err)
}
if got, want := time.Since(start), delay; got != want {
t.Errorf("%v: returned after %v, want %v", test.what, got, want)
}
done.Store(true)
}()
synctest.Wait()
if done.Load() {
t.Fatalf("%v: unexpectedly returned before unblocking", test.what)
}
time.Sleep(delay)
synctest.Wait()
if !done.Load() {
t.Fatalf("%v: did not return after deadline", test.what)
}
})
synctestSubtest(t, "deadline already expired", func(t *testing.T) {
test := setup()
test.setDeadline(-1 * time.Second)
test.unblock()
synctest.Wait()
if err := test.block(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("%v: %v, want os.ErrDeadlineExceeded", test.what, err)
}
})
synctestSubtest(t, "reduce deadline after blocking", func(t *testing.T) {
test := setup()
test.setDeadline(5 * time.Second)
var done bool
go func() {
if err := test.block(); !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("%v: %v, want os.ErrDeadlineExceeded", test.what, err)
}
done = true
}()
synctest.Wait()
if done {
t.Fatalf("%v: unexpectedly returned before reducing deadline", test.what)
}
test.setDeadline(-1 * time.Second)
synctest.Wait()
if !done {
t.Fatalf("%v: did not return after deadline", test.what)
}
})
}

View file

@ -0,0 +1,250 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
import (
"context"
"errors"
"internal/gate"
"net"
"os"
"slices"
"sync"
"time"
)
// A PacketNet is a group of communicating [PacketConn]s.
type PacketNet struct {
mu sync.Mutex
conns map[netAddr]*PacketConn
}
type netAddr struct {
network string
addr string
}
// NewPacketNet returns a new PacketNet.
func NewPacketNet() *PacketNet {
return &PacketNet{
conns: make(map[netAddr]*PacketConn),
}
}
// NewConn returns a new [PacketConn] listening on the given address.
// It returns an error if there is an existing listener on this address.
func (n *PacketNet) NewConn(a net.Addr) (*PacketConn, error) {
n.mu.Lock()
defer n.mu.Unlock()
addrKey := netAddr{a.Network(), a.String()}
if _, ok := n.conns[addrKey]; ok {
return nil, &net.OpError{
Op: "listen",
Net: "udp",
Addr: a,
Err: errors.New("address is in use"),
}
}
p := &PacketConn{
gate: gate.New(false),
addr: a,
net: n,
}
n.conns[addrKey] = p
return p, nil
}
type PacketConn struct {
gate gate.Gate
queue queue[*packet]
closed bool
readErr error
writeErr error
closeErr error
readDeadline connDeadline
net *PacketNet
addr net.Addr
}
type packet struct {
b []byte
src net.Addr
}
// ReadFrom reads a packet from the connection, copying the payload into b.
func (p *PacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
p.gate.WaitAndLock(context.Background())
defer p.unlock()
switch {
case p.closed:
err = net.ErrClosed
case p.readDeadline.expired:
err = os.ErrDeadlineExceeded
case p.queue.len() == 0 && p.readErr != nil:
err = p.readErr
}
if err != nil {
return 0, nil, &net.OpError{
Op: "read",
Net: "udp",
Addr: p.addr,
Err: err,
}
}
pkt := p.queue.pop()
n = copy(b, pkt.b)
return n, pkt.src, nil
}
// WriteTo writes a packet with payload b to addr.
// addr must be a [*net.UDPAddr].
//
// WriteTo appends the packet to the recipient's receive buffer.
// If no recipient is listening on addr or if the recipient's
// receive buffer is full, the packet is silently discarded.
func (p *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
p.gate.Lock()
switch {
case p.closed:
err = net.ErrClosed
case p.writeErr != nil:
err = p.writeErr
}
p.unlock()
if err != nil {
return 0, &net.OpError{
Op: "write",
Net: "udp",
Source: p.addr,
Addr: addr,
Err: err,
}
}
p.net.mu.Lock()
dst := p.net.conns[netAddr{addr.Network(), addr.String()}]
p.net.mu.Unlock()
if dst == nil {
// There is no PacketConn listening on the destination address,
// and the packet falls silently into the void.
return len(b), nil
}
dst.lock()
if !dst.closed {
dst.queue.push(&packet{b: slices.Clone(b), src: p.addr})
}
dst.unlock()
return len(b), nil
}
// Close closes the connection.
func (p *PacketConn) Close() error {
p.net.mu.Lock()
delete(p.net.conns, netAddr{p.addr.Network(), p.addr.String()})
p.net.mu.Unlock()
p.lock()
defer p.unlock()
err := p.closeErr
p.closed = true
p.readErr = net.ErrClosed
p.writeErr = net.ErrClosed
p.closeErr = net.ErrClosed
if err != nil {
return &net.OpError{
Op: "close",
Net: "udp",
Addr: p.addr,
Err: err,
}
}
return err
}
// LocalAddr returns the (fake) local network address.
func (p *PacketConn) LocalAddr() net.Addr {
p.lock()
defer p.unlock()
return p.addr
}
// SetReadDeadline sets the read deadline for the connection.
// PacketConns have no write deadline.
func (p *PacketConn) SetDeadline(t time.Time) error {
return p.SetReadDeadline(t)
}
// SetReadDeadline sets the read deadline for the connection.
func (p *PacketConn) SetReadDeadline(t time.Time) error {
p.readDeadline.setDeadline(p, t)
return nil
}
// SetWriteDeadline has no effect.
// Writes to PacketConns never block.
func (p *PacketConn) SetWriteDeadline(t time.Time) error {
return nil
}
// SetReadError causes any currently blocked and future ReadFrom calls to return
// a net.OpError wrapping err. It does not affect the other half of the connection.
// Reads will return any buffered data before returning the error,
// including data written after the error is set.
// A nil error restores the usual behavior.
func (c *PacketConn) SetReadError(err error) {
c.lock()
defer c.unlock()
c.readErr = err
}
// SetWriteError causes any currently blocked and future WriteTo calls to return
// a net.OpError wrapping err. It does not affect the other half of the connection.
// Writes will not write data while an error is set.
// A nil error restores the usual behavior.
func (c *PacketConn) SetWriteError(err error) {
c.lock()
defer c.unlock()
c.writeErr = err
}
// SetCloseError sets the error returned by Close.
// Close still closes the connection.
// A nil error restores the usual behavior.
func (c *PacketConn) SetCloseError(err error) {
c.lock()
defer c.unlock()
c.closeErr = err
}
// CanRead reports whether [ReadFrom] can return at least one byte or an error.
// If [ReadFrom] would block, CanRead returns false.
func (p *PacketConn) CanRead() bool {
p.lock()
defer p.unlock()
return p.canReadLocked()
}
func (p *PacketConn) canReadLocked() bool {
return p.queue.len() > 0 || p.readDeadline.expired || p.closed || p.readErr != nil
}
// IsClosed reports whether the connection has been closed.
func (p *PacketConn) IsClosed() bool {
p.lock()
defer p.unlock()
return p.closed
}
func (p *PacketConn) lock() {
p.gate.Lock()
}
func (p *PacketConn) unlock() {
p.gate.Unlock(p.canReadLocked())
}

View file

@ -0,0 +1,329 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest_test
import (
"bytes"
"errors"
"internal/nettest"
"io"
"net"
"net/netip"
"os"
"testing"
"testing/synctest"
"time"
)
func TestPacketConnListenConflict(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
addr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("10.0.0.1:1000"))
pnet := nettest.NewPacketNet()
conn, err := pnet.NewConn(addr)
if err != nil {
t.Fatalf("with no existing listener, pnet.NewConn(%v) = %v; want success", addr, err)
}
_, err = pnet.NewConn(addr)
if err == nil {
t.Fatalf("with existing listener, pnet.NewConn(%v) = nil; want error", addr)
}
conn.Close()
_, err = pnet.NewConn(addr)
if err != nil {
t.Fatalf("after closing existing listener, pnet.NewConn(%v) = %v; want success", addr, err)
}
})
}
func TestPacketConnReadWrite(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
pnet := nettest.NewPacketNet()
c1 := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
c2 := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
c3 := mustNewPacketConn(t, pnet, "10.0.0.3:3000")
wantPacketConnWriteTo(t, c1, []byte("1->3"), c3.LocalAddr())
wantPacketConnWriteTo(t, c2, []byte("2->3"), c3.LocalAddr())
wantPacketConnWriteTo(t, c3, []byte("3->1"), c1.LocalAddr())
wantPacketConnReadBytes(t, c1, []byte("3->1"), c3.LocalAddr())
wantPacketConnReadBytes(t, c3, []byte("1->3"), c1.LocalAddr())
wantPacketConnReadBytes(t, c3, []byte("2->3"), c2.LocalAddr())
wantPacketConnReadBlocked(t, c1)
wantPacketConnReadBlocked(t, c2)
wantPacketConnReadBlocked(t, c3)
// Write a packet into the void (no listener on this address).
wantPacketConnWriteTo(t, c1, []byte("1->lost"), net.UDPAddrFromAddrPort(netip.MustParseAddrPort("10.0.0.100:1000")))
})
}
func TestPacketConnWriteAddressErrors(t *testing.T) {
t.Skip("TODO: figure out if these should be errors")
synctest.Test(t, func(t *testing.T) {
pnet := nettest.NewPacketNet()
c4 := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
c6 := mustNewPacketConn(t, pnet, "[::1]:1000")
wantPacketConnWriteErr(t, c4, c6.LocalAddr(), anyError) // IPv4 -> IPv6
wantPacketConnWriteErr(t, c6, c4.LocalAddr(), anyError) // IPv6 -> IPv4
// Not a *net.UDPAddr.
wantPacketConnWriteErr(t, c4, net.UDPAddrFromAddrPort(netip.MustParseAddrPort("10.0.0.1:1000")), anyError)
})
}
func TestPacketConnClose(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
pnet := nettest.NewPacketNet()
pconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
wantPacketConnWriteTo(t, pconn, []byte("hello"), pconn.LocalAddr())
if err := pconn.Close(); err != nil {
t.Errorf("pconn.Close() = %v, want success", err)
}
if err := pconn.Close(); !isOpError(err, net.ErrClosed) {
t.Errorf("pconn.Close() = %v, want ErrClosed", err)
}
wantPacketConnReadErr(t, pconn, net.ErrClosed)
wantPacketConnWriteErr(t, pconn, pconn.LocalAddr(), net.ErrClosed)
})
}
func TestPacketConnReadDeadline(t *testing.T) {
for _, setDeadline := range []struct {
name string
f func(*nettest.PacketConn, time.Time) error
}{{
name: "SetDeadline",
f: (*nettest.PacketConn).SetDeadline,
}, {
name: "SetReadDeadline",
f: (*nettest.PacketConn).SetReadDeadline,
}} {
t.Run(setDeadline.name, func(t *testing.T) {
testDeadline(t, func() deadlineTest {
pnet := nettest.NewPacketNet()
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
return deadlineTest{
what: "ReadFrom()",
block: func() error {
_, _, err := rconn.ReadFrom(make([]byte, 1))
return err
},
unblock: func() {
wconn.WriteTo([]byte("x"), rconn.LocalAddr())
},
setDeadline: func(d time.Duration) {
setDeadline.f(rconn, time.Now().Add(d))
},
}
})
})
}
}
func TestPacketConnWriteDeadline(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
pnet := nettest.NewPacketNet()
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
// This does nothing, even though the deadline has expired.
wconn.SetWriteDeadline(time.Now().Add(-1 * time.Second))
wantPacketConnWriteTo(t, wconn, []byte("data"), rconn.LocalAddr())
wantPacketConnReadBytes(t, rconn, []byte("data"), wconn.LocalAddr())
})
}
func TestPacketConnSetReadError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
wantErr := errors.New("error")
pnet := nettest.NewPacketNet()
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
rconn.SetReadError(wantErr)
// Consume buffer before returning error.
wantPacketConnWriteTo(t, wconn, []byte("one"), rconn.LocalAddr())
wantPacketConnReadBytes(t, rconn, []byte("one"), wconn.LocalAddr())
wantPacketConnReadErr(t, rconn, wantErr)
// Write more, suppressing error until buffer drains again.
wantPacketConnWriteTo(t, wconn, []byte("two"), rconn.LocalAddr())
wantPacketConnReadBytes(t, rconn, []byte("two"), wconn.LocalAddr())
wantPacketConnReadErr(t, rconn, wantErr)
// Error may be cleared.
rconn.SetReadError(nil)
wantPacketConnReadBlocked(t, rconn)
// ErrClosed takes precedence over read error.
rconn.SetReadError(wantErr)
rconn.Close()
wantPacketConnReadErr(t, rconn, net.ErrClosed)
})
}
func TestPacketConnSetWriteError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
wantErr := errors.New("error")
pnet := nettest.NewPacketNet()
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
wconn.SetWriteError(wantErr)
// Error blocks writes.
wantPacketConnWriteErr(t, wconn, rconn.LocalAddr(), wantErr)
wantPacketConnReadBlocked(t, rconn)
// Error may be cleared.
wconn.SetWriteError(nil)
wantPacketConnWriteTo(t, wconn, []byte("one"), rconn.LocalAddr())
// Restoring error does not prevent reading buffered data.
wconn.SetWriteError(wantErr)
wantPacketConnWriteErr(t, wconn, rconn.LocalAddr(), wantErr)
wantPacketConnReadBytes(t, rconn, []byte("one"), wconn.LocalAddr())
// Error does not interfere with closing the conn.
wconn.Close()
wantPacketConnWriteErr(t, wconn, rconn.LocalAddr(), net.ErrClosed)
})
}
func TestPacketConnSetCloseError(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
wantErr := errors.New("error")
pnet := nettest.NewPacketNet()
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
wconn.SetCloseError(wantErr)
wantPacketConnWriteTo(t, wconn, []byte("one"), rconn.LocalAddr())
if err := wconn.Close(); !isOpError(err, wantErr) {
t.Fatalf("wconn.Close = %v, want OpError{Err: %v}", err, wantErr)
}
if err := wconn.Close(); !isOpError(err, net.ErrClosed) {
t.Fatalf("wconn.Close = %v, want OpError{Err: net.ErrClosed}", err)
}
wantPacketConnReadBytes(t, rconn, []byte("one"), wconn.LocalAddr())
})
}
func TestPacketConnCanRead(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
pnet := nettest.NewPacketNet()
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
if got, want := rconn.CanRead(), false; got != want {
t.Fatalf("before writing data: rconn.CanRead() = %v, want %v", got, want)
}
wconn.WriteTo([]byte("a"), rconn.LocalAddr())
if got, want := rconn.CanRead(), true; got != want {
t.Fatalf("after writing data: rconn.CanRead() = %v, want %v", got, want)
}
rconn.ReadFrom(make([]byte, 1))
if got, want := rconn.CanRead(), false; got != want {
t.Fatalf("after reading data: rconn.CanRead() = %v, want %v", got, want)
}
})
}
func TestPacketConnIsClosed(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
pnet := nettest.NewPacketNet()
conn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
if got, want := conn.IsClosed(), false; got != want {
t.Fatalf("before closing: conn.IsClosed() = %v, want %v", got, want)
}
conn.Close()
if got, want := conn.IsClosed(), true; got != want {
t.Fatalf("after closing: conn.IsClosed() = %v, want %v", got, want)
}
})
}
func mustNewPacketConn(t *testing.T, pnet *nettest.PacketNet, addr string) *nettest.PacketConn {
t.Helper()
c, err := pnet.NewConn(net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addr)))
if err != nil {
t.Fatal(err)
}
return c
}
func wantPacketConnWriteTo(t *testing.T, c *nettest.PacketConn, b []byte, dst net.Addr) {
t.Helper()
if n, err := c.WriteTo(b, dst); n != len(b) || err != nil {
t.Fatalf("conn.WriteTo(%q, %v) = %v, %v; want %v, nil", b, dst, n, err, len(b))
}
}
func wantPacketConnWriteErr(t *testing.T, c *nettest.PacketConn, dst net.Addr, want error) {
t.Helper()
n, err := c.WriteTo(make([]byte, 1), dst)
if n != 0 || !isOpError(err, want) {
t.Fatalf("c.WriteTo() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
}
}
func wantPacketConnReadBytes(t *testing.T, c *nettest.PacketConn, want []byte, wantAddr net.Addr) {
t.Helper()
udpWantAddr, ok := wantAddr.(*net.UDPAddr)
if !ok {
t.Fatalf("wantAddr is %T, should be *net.UDPAddr", wantAddr)
}
got := make([]byte, len(want)+1)
n, addr, err := c.ReadFrom(got)
got = got[:n]
udpAddr, addrOK := addr.(*net.UDPAddr)
if n != len(want) || !addrOK || udpAddr.AddrPort() != udpWantAddr.AddrPort() {
t.Fatalf("conn.ReadFrom() = %v, %v, %v; want %v, %v, nil", n, addr, err, len(want), wantAddr)
}
if !bytes.Equal(got, want) {
t.Fatalf("conn.ReadFrom() read %q, want %q", got, want)
}
}
func wantPacketConnReadErr(t *testing.T, c *nettest.PacketConn, want error) {
t.Helper()
n, addr, err := c.ReadFrom(make([]byte, 1))
if want == io.EOF {
if n != 0 || err != io.EOF {
t.Fatalf("c.ReadFrom() = %v, %v, %v; want 0, nil, io.EOF", n, addr, err)
}
} else {
if n != 0 || !isOpError(err, want) {
t.Fatalf("c.ReadFrom() = %v, %v, %v; want 0, nil, OpError{Err: %q}", n, addr, err, want)
}
}
}
func wantPacketConnReadBlocked(t *testing.T, c *nettest.PacketConn) {
done := false
go func() {
n, addr, err := c.ReadFrom(make([]byte, 1))
if n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
t.Errorf("c.Read() = %v, %v, %v; want 0, nil, ErrDeadlineExceeded", n, addr, err)
}
done = true
}()
synctest.Wait()
if done {
t.Fatalf("ReadFrom unexpectedly returned before setting deadline")
}
c.SetReadDeadline(time.Now().Add(-1 * time.Second))
synctest.Wait()
c.SetReadDeadline(time.Time{})
if !done {
t.Fatalf("ReadFrom unexpectedly did not return after setting deadline")
}
}

View file

@ -0,0 +1,33 @@
// Copyright 2026 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
type queue[T any] struct {
headPos int
head []T
tail []T
}
func (q *queue[T]) len() int {
return len(q.head[q.headPos:]) + len(q.tail)
}
func (q *queue[T]) push(v T) {
q.tail = append(q.tail, v)
}
func (q *queue[T]) pop() T {
var zero T
if q.headPos >= len(q.head) {
if len(q.tail) == 0 {
return zero
}
q.head, q.headPos, q.tail = q.tail, 0, q.head[:0]
}
v := q.head[q.headPos]
q.head[q.headPos] = zero
q.headPos++
return v
}