mirror of
https://github.com/golang/go.git
synced 2026-06-27 19:30:52 +00:00
internal/nettest: add internal fake networking implementation
This is an implementation of proposal #77362. Add it as an internal package so we can use it for fake networking support in net/http/httptest. This can also eventually replace the fake network implementations in net/http and net/http/internal/http2. For #77362 For #76608 Change-Id: Ida069e139ede90bf01e413f2a69cf0116a6a6964 Reviewed-on: https://go-review.googlesource.com/c/go/+/769520 LUCI-TryBot-Result: golang-scoped@luci-project-accounts.iam.gserviceaccount.com <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Nicholas Husin <husin@google.com> Reviewed-by: Nicholas Husin <nsh@golang.org> Auto-Submit: Damien Neil <dneil@google.com>
This commit is contained in:
parent
2e67b18935
commit
a871fd3732
9 changed files with 1996 additions and 0 deletions
|
|
@ -685,6 +685,9 @@ var depsRules = `
|
|||
net/http, net/http/internal/ascii
|
||||
< net/http/cookiejar, net/http/httputil;
|
||||
|
||||
NET, internal/gate
|
||||
< internal/nettest;
|
||||
|
||||
net/http, flag
|
||||
< net/http/httptest;
|
||||
|
||||
|
|
|
|||
437
src/internal/nettest/conn.go
Normal file
437
src/internal/nettest/conn.go
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Conn is an in-memory test implementation of net.Conn.
|
||||
type Conn struct {
|
||||
// Conns come in pairs.
|
||||
// Writes to one Conn are read by its peer, and vice-versa.
|
||||
//
|
||||
// A connHalf handles one direction of data flow.
|
||||
// A Conn consists of read and write halves.
|
||||
// A Conn's peer has the same halves, only swapped.
|
||||
//
|
||||
// A Conn reads from r and writes to w.
|
||||
r, w *connHalf
|
||||
|
||||
// peer is the other endpoint.
|
||||
peer *Conn
|
||||
}
|
||||
|
||||
// NewConnPair returns a pair of connected Conns.
|
||||
func NewConnPair() (*Conn, *Conn) {
|
||||
return newConnPair(
|
||||
net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10000")),
|
||||
net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:10001")),
|
||||
)
|
||||
}
|
||||
|
||||
func newConnPair(addr1, addr2 net.Addr) (*Conn, *Conn) {
|
||||
h1 := newConnHalf(addr1)
|
||||
h2 := newConnHalf(addr2)
|
||||
c1 := &Conn{r: h1, w: h2}
|
||||
c2 := &Conn{r: h2, w: h1}
|
||||
c1.peer = c2
|
||||
c2.peer = c1
|
||||
c1.SetReadBufferSize(-1)
|
||||
c2.SetReadBufferSize(-1)
|
||||
return c1, c2
|
||||
}
|
||||
|
||||
// Peer returns the other end of the connection.
|
||||
func (c *Conn) Peer() *Conn {
|
||||
return c.peer
|
||||
}
|
||||
|
||||
// Read reads data from the connection.
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.r.read(b)
|
||||
if err != nil && err != io.EOF {
|
||||
err = &net.OpError{
|
||||
Op: "read",
|
||||
Net: "tcp",
|
||||
Source: c.RemoteAddr(),
|
||||
Addr: c.LocalAddr(),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// CanRead reports whether Read can proceed without blocking.
|
||||
func (c *Conn) CanRead() bool {
|
||||
return c.r.canRead()
|
||||
}
|
||||
|
||||
// Write writes data to the connection.
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.w.write(b)
|
||||
if err != nil {
|
||||
err = &net.OpError{
|
||||
Op: "write",
|
||||
Net: "tcp",
|
||||
Source: c.LocalAddr(),
|
||||
Addr: c.RemoteAddr(),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// IsClosed reports whether the connection has been closed.
|
||||
// A connection is closed if [CloseRead] and [CloseWrite] are both called,
|
||||
// or if [Close] is called.
|
||||
//
|
||||
// To identify when the other side of the Conn has been closed,
|
||||
// use Conn.Peer().IsClosed().
|
||||
func (c *Conn) IsClosed() bool {
|
||||
c.r.lock()
|
||||
readClosed := c.r.readClosed
|
||||
c.r.unlock()
|
||||
c.w.lock()
|
||||
writeClosed := c.w.writeClosed
|
||||
c.w.unlock()
|
||||
return readClosed && writeClosed
|
||||
}
|
||||
|
||||
var errClosedByPeer = errors.New("connection closed by peer")
|
||||
|
||||
// CloseRead shuts down the reading side of the connection.
|
||||
func (c *Conn) CloseRead() error {
|
||||
c.r.lock()
|
||||
defer c.r.unlock()
|
||||
c.r.buf.Reset() // discard unread data
|
||||
c.r.readClosed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseWrite shuts down the writing side of the connection.
|
||||
func (c *Conn) CloseWrite() error {
|
||||
c.w.lock()
|
||||
defer c.w.unlock()
|
||||
c.w.writeClosed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
c.r.lock()
|
||||
readClosed := c.r.readClosed
|
||||
c.r.buf.Reset() // discard unread data
|
||||
c.r.readClosed = true
|
||||
err := c.r.closeErr
|
||||
c.r.unlock()
|
||||
|
||||
c.w.lock()
|
||||
writeClosed := c.w.writeClosed
|
||||
c.w.writeClosed = true
|
||||
c.w.unlock()
|
||||
|
||||
if readClosed && writeClosed {
|
||||
err = net.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
err = &net.OpError{
|
||||
Op: "close",
|
||||
Net: "tcp",
|
||||
Addr: c.LocalAddr(),
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// SetCloseError sets the error returned by Close.
|
||||
// Close still closes the connection.
|
||||
// A nil error restores the usual behavior.
|
||||
func (c *Conn) SetCloseError(err error) {
|
||||
c.r.lock()
|
||||
c.r.closeErr = err
|
||||
c.r.unlock()
|
||||
}
|
||||
|
||||
// LocalAddr returns the (fake) local network address.
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
c.r.lock()
|
||||
defer c.r.unlock()
|
||||
return c.r.addr
|
||||
}
|
||||
|
||||
// SetLocalAddr sets the local address.
|
||||
//
|
||||
// To set the remote address, set the local address of Conn's peer.
|
||||
func (c *Conn) SetLocalAddr(addr net.Addr) {
|
||||
c.r.lock()
|
||||
defer c.r.unlock()
|
||||
c.r.addr = addr
|
||||
}
|
||||
|
||||
// LocalAddr returns the (fake) remote network address.
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
c.r.lock()
|
||||
defer c.r.unlock()
|
||||
return c.w.addr
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines for the connection.
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
c.SetReadDeadline(t)
|
||||
c.SetWriteDeadline(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline for the connection.
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
c.r.readDeadline.setDeadline(c.r, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline for the connection.
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
c.w.writeDeadline.setDeadline(c.w, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadBufferSize sets the connection's read buffer.
|
||||
// Writes to the other end of the connection will block so long as the buffer is full.
|
||||
// Setting the size to 0 blocks all writes until the size is increased.
|
||||
func (c *Conn) SetReadBufferSize(size int) {
|
||||
if size < 0 {
|
||||
size = math.MaxInt
|
||||
}
|
||||
c.r.setBufferSize(size)
|
||||
}
|
||||
|
||||
// SetReadError causes any currently blocked and future Read calls to return
|
||||
// a net.OpError wrapping err. It does not affect the other half of the connection.
|
||||
// Reads will return any buffered data before returning the error,
|
||||
// including data written after the error is set and io.EOF after the other end is closed.
|
||||
// A nil error restores the usual behavior.
|
||||
func (c *Conn) SetReadError(err error) {
|
||||
c.r.lock()
|
||||
defer c.r.unlock()
|
||||
c.r.readErr = err
|
||||
}
|
||||
|
||||
// SetWriteError causes any currently blocked and future Write calls to return
|
||||
// a net.OpError wrapping err. It does not affect the other half of the connection.
|
||||
// Writes will not write data to the connection buffer while an error is set.
|
||||
// A nil error restores the usual behavior.
|
||||
func (c *Conn) SetWriteError(err error) {
|
||||
c.w.lock()
|
||||
defer c.w.unlock()
|
||||
c.w.writeErr = err
|
||||
}
|
||||
|
||||
// connHalf is one direction data flow in a Conn.
|
||||
// The connHalf contains a buffer.
|
||||
// Writes to the connHalf push to the buffer and reads pull from it.
|
||||
type connHalf struct {
|
||||
addr net.Addr
|
||||
|
||||
// A half can be readable and/or writable.
|
||||
//
|
||||
// These four channels act as a lock,
|
||||
// and allow waiting for readability/writability.
|
||||
// When the half is unlocked, exactly one channel contains a value.
|
||||
// When the half is locked, all channels are empty.
|
||||
lockr chan struct{} // readable
|
||||
lockw chan struct{} // writable
|
||||
lockrw chan struct{} // readable and writable
|
||||
lockc chan struct{} // neither readable nor writable
|
||||
|
||||
// Read and write timeouts.
|
||||
readDeadline, writeDeadline connDeadline
|
||||
|
||||
bufMax int // maximum buffer size
|
||||
buf bytes.Buffer
|
||||
|
||||
readClosed, writeClosed bool
|
||||
readErr, writeErr error // errors returned by reads/writes
|
||||
closeErr error // error returned by closing the conn reading from this half
|
||||
}
|
||||
|
||||
func newConnHalf(addr net.Addr) *connHalf {
|
||||
h := &connHalf{
|
||||
addr: addr,
|
||||
lockw: make(chan struct{}, 1),
|
||||
lockr: make(chan struct{}, 1),
|
||||
lockrw: make(chan struct{}, 1),
|
||||
lockc: make(chan struct{}, 1),
|
||||
bufMax: math.MaxInt, // unlimited
|
||||
}
|
||||
h.unlock()
|
||||
return h
|
||||
}
|
||||
|
||||
// lock locks h.
|
||||
func (h *connHalf) lock() {
|
||||
select {
|
||||
case <-h.lockw: // writable
|
||||
case <-h.lockr: // readable
|
||||
case <-h.lockrw: // readable and writable
|
||||
case <-h.lockc: // neither readable nor writable
|
||||
}
|
||||
}
|
||||
|
||||
// unlock unlocks h.
|
||||
func (h *connHalf) unlock() {
|
||||
canRead := h.canReadLocked()
|
||||
canWrite := h.canWriteLocked()
|
||||
switch {
|
||||
case canRead && canWrite:
|
||||
h.lockrw <- struct{}{} // readable and writable
|
||||
case canRead:
|
||||
h.lockr <- struct{}{} // readable
|
||||
case canWrite:
|
||||
h.lockw <- struct{}{} // writable
|
||||
default:
|
||||
h.lockc <- struct{}{} // neither readable nor writable
|
||||
}
|
||||
}
|
||||
|
||||
func (h *connHalf) canRead() bool {
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
return h.canReadLocked()
|
||||
}
|
||||
|
||||
func (h *connHalf) canReadLocked() bool {
|
||||
return h.readErr != nil || h.readDeadline.expired || h.buf.Len() > 0 || h.readClosed || h.writeClosed
|
||||
}
|
||||
|
||||
func (h *connHalf) canWriteLocked() bool {
|
||||
return h.writeErr != nil || h.writeDeadline.expired || h.bufMax > h.buf.Len() || h.readClosed || h.writeClosed
|
||||
}
|
||||
|
||||
// waitAndLockForRead waits until h is readable and locks it.
|
||||
func (h *connHalf) waitAndLockForRead() {
|
||||
select {
|
||||
case <-h.lockr:
|
||||
// readable
|
||||
case <-h.lockrw:
|
||||
// readable and writable
|
||||
}
|
||||
}
|
||||
|
||||
// waitAndLockForWrite waits until h is writable and locks it.
|
||||
func (h *connHalf) waitAndLockForWrite() {
|
||||
select {
|
||||
case <-h.lockw:
|
||||
// writable
|
||||
case <-h.lockrw:
|
||||
// readable and writable
|
||||
}
|
||||
}
|
||||
|
||||
func (h *connHalf) read(b []byte) (n int, err error) {
|
||||
h.waitAndLockForRead()
|
||||
defer h.unlock()
|
||||
if h.readClosed {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if h.readDeadline.expired {
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
if h.buf.Len() > 0 {
|
||||
return h.buf.Read(b)
|
||||
}
|
||||
if h.writeClosed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return 0, h.readErr
|
||||
}
|
||||
|
||||
func (h *connHalf) setBufferSize(size int) {
|
||||
h.lock()
|
||||
defer h.unlock()
|
||||
h.bufMax = size
|
||||
}
|
||||
|
||||
func (h *connHalf) write(b []byte) (n int, err error) {
|
||||
for n < len(b) {
|
||||
nn, err := h.writePartial(b[n:])
|
||||
n += nn
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (h *connHalf) writePartial(b []byte) (n int, err error) {
|
||||
h.waitAndLockForWrite()
|
||||
defer h.unlock()
|
||||
if h.writeClosed {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if h.writeDeadline.expired {
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
if h.readClosed {
|
||||
return 0, errClosedByPeer
|
||||
}
|
||||
if h.writeErr != nil {
|
||||
return 0, h.writeErr
|
||||
}
|
||||
writeMax := h.bufMax - h.buf.Len()
|
||||
if writeMax < len(b) {
|
||||
b = b[:writeMax]
|
||||
}
|
||||
return h.buf.Write(b)
|
||||
}
|
||||
|
||||
type connDeadline struct {
|
||||
timer *time.Timer
|
||||
expired bool
|
||||
}
|
||||
|
||||
type locker interface {
|
||||
lock()
|
||||
unlock()
|
||||
}
|
||||
|
||||
func (d *connDeadline) setDeadline(mu locker, t time.Time) {
|
||||
mu.lock()
|
||||
defer mu.unlock()
|
||||
if d.timer != nil {
|
||||
d.timer.Stop()
|
||||
d.timer = nil
|
||||
}
|
||||
if t.IsZero() {
|
||||
// No deadline.
|
||||
d.expired = false
|
||||
return
|
||||
}
|
||||
expiry := time.Until(t)
|
||||
if expiry <= 0 {
|
||||
// Deadline has already passed.
|
||||
d.expired = true
|
||||
return
|
||||
}
|
||||
// Deadline is in the future.
|
||||
d.expired = false
|
||||
var timer *time.Timer
|
||||
timer = time.AfterFunc(expiry, func() {
|
||||
mu.lock()
|
||||
defer mu.unlock()
|
||||
if d.timer == timer {
|
||||
d.timer = nil
|
||||
d.expired = true
|
||||
}
|
||||
})
|
||||
d.timer = timer
|
||||
}
|
||||
450
src/internal/nettest/conn_test.go
Normal file
450
src/internal/nettest/conn_test.go
Normal file
|
|
@ -0,0 +1,450 @@
|
|||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"internal/nettest"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConnReadWrite(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
cliConn, srvConn := nettest.NewConnPair()
|
||||
|
||||
cliData := []byte("hello")
|
||||
srvData := []byte("HELLO")
|
||||
if n, err := cliConn.Write(cliData); n != len(cliData) || err != nil {
|
||||
t.Fatalf("cliConn.Write(%q) = %v, %v; want %v, nil", cliData, n, err, len(cliData))
|
||||
}
|
||||
if err := cliConn.CloseWrite(); err != nil {
|
||||
t.Fatalf("cliConn.CloseWrite() = %v, want nil", err)
|
||||
}
|
||||
if n, err := srvConn.Write(srvData); n != len(srvData) || err != nil {
|
||||
t.Fatalf("srvConn.Write(%q) = %v, %v; want %v, nil", srvData, n, err, len(srvData))
|
||||
}
|
||||
if err := srvConn.CloseWrite(); err != nil {
|
||||
t.Fatalf("cliConn.CloseWrite() = %v, want nil", err)
|
||||
}
|
||||
gotCli, err := io.ReadAll(cliConn)
|
||||
if !bytes.Equal(gotCli, srvData) || err != nil {
|
||||
t.Fatalf("io.ReadAll(cliConn) = %q, %v; want %v, nil", gotCli, err, srvData)
|
||||
}
|
||||
gotSrv, err := io.ReadAll(srvConn)
|
||||
if !bytes.Equal(gotSrv, cliData) || err != nil {
|
||||
t.Fatalf("io.ReadAll(srvConn) = %q, %v; want %v, nil", gotSrv, err, cliData)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnZeroBuffer(t *testing.T) {
|
||||
// Exercise the case where one side of the conn is blocked writing and the
|
||||
// other side is blocked reading.
|
||||
// This can only happen when the read buffer has been set to 0, blocking all writes.
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
rconn, wconn := nettest.NewConnPair()
|
||||
rconn.SetReadBufferSize(0)
|
||||
var readDone, writeDone bool
|
||||
go func() {
|
||||
rconn.Read(make([]byte, 100))
|
||||
readDone = true
|
||||
}()
|
||||
go func() {
|
||||
wconn.Write([]byte("a"))
|
||||
writeDone = true
|
||||
}()
|
||||
synctest.Wait()
|
||||
if readDone || writeDone {
|
||||
t.Errorf("before unblocking: readDone=%v, writeDone=%v; want false", readDone, writeDone)
|
||||
}
|
||||
wconn.Close()
|
||||
synctest.Wait()
|
||||
if !readDone || !writeDone {
|
||||
t.Errorf("after unblocking: readDone=%v, writeDone=%v; want true", readDone, writeDone)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnPartialWrite(t *testing.T) {
|
||||
// A blocking write to a conn successfully writes some, but not all data.
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
const readSize = 5
|
||||
data := []byte("0123456789")
|
||||
rconn, wconn := nettest.NewConnPair()
|
||||
rconn.SetReadBufferSize(1)
|
||||
go func() {
|
||||
got := make([]byte, readSize)
|
||||
if n, err := io.ReadFull(rconn, got); n != readSize || err != nil {
|
||||
t.Errorf("io.ReadFull() = %v, %v; want %v, nil", n, err, readSize)
|
||||
}
|
||||
if want := data[:readSize]; !bytes.Equal(got, want) {
|
||||
t.Errorf("read %q, want %q", got, want)
|
||||
}
|
||||
rconn.Close()
|
||||
}()
|
||||
n, err := wconn.Write(data)
|
||||
if n != readSize+1 || err == nil {
|
||||
t.Errorf("Write() = %v, %v; want %v, error", n, err, readSize+1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnReadDeadline(t *testing.T) {
|
||||
for _, unblock := range []struct {
|
||||
name string
|
||||
f func(*nettest.Conn)
|
||||
}{{
|
||||
name: "Write",
|
||||
f: func(c *nettest.Conn) {
|
||||
c.Write([]byte("x"))
|
||||
},
|
||||
}, {
|
||||
name: "Close",
|
||||
f: func(c *nettest.Conn) {
|
||||
c.Close()
|
||||
},
|
||||
}, {
|
||||
name: "CloseWrite",
|
||||
f: func(c *nettest.Conn) {
|
||||
c.CloseWrite()
|
||||
},
|
||||
}} {
|
||||
for _, setDeadline := range []struct {
|
||||
name string
|
||||
f func(*nettest.Conn, time.Time) error
|
||||
}{{
|
||||
name: "SetDeadline",
|
||||
f: (*nettest.Conn).SetDeadline,
|
||||
}, {
|
||||
name: "SetReadDeadline",
|
||||
f: (*nettest.Conn).SetReadDeadline,
|
||||
}} {
|
||||
t.Run(unblock.name+"/"+setDeadline.name, func(t *testing.T) {
|
||||
testDeadline(t, func() deadlineTest {
|
||||
rconn, wconn := nettest.NewConnPair()
|
||||
return deadlineTest{
|
||||
what: "Read()",
|
||||
block: func() error {
|
||||
_, err := rconn.Read(make([]byte, 1))
|
||||
return err
|
||||
},
|
||||
unblock: func() {
|
||||
unblock.f(wconn)
|
||||
},
|
||||
setDeadline: func(d time.Duration) {
|
||||
setDeadline.f(rconn, time.Now().Add(d))
|
||||
},
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnWriteDeadline(t *testing.T) {
|
||||
for _, unblock := range []struct {
|
||||
name string
|
||||
f func(*nettest.Conn)
|
||||
}{{
|
||||
name: "Read",
|
||||
f: func(c *nettest.Conn) {
|
||||
io.Copy(io.Discard, c)
|
||||
},
|
||||
}, {
|
||||
name: "Close",
|
||||
f: func(c *nettest.Conn) {
|
||||
c.Close()
|
||||
},
|
||||
}, {
|
||||
name: "CloseRead",
|
||||
f: func(c *nettest.Conn) {
|
||||
c.CloseRead()
|
||||
},
|
||||
}} {
|
||||
for _, setDeadline := range []struct {
|
||||
name string
|
||||
f func(*nettest.Conn, time.Time) error
|
||||
}{{
|
||||
name: "SetDeadline",
|
||||
f: (*nettest.Conn).SetDeadline,
|
||||
}, {
|
||||
name: "SetWriteDeadline",
|
||||
f: (*nettest.Conn).SetWriteDeadline,
|
||||
}} {
|
||||
t.Run(unblock.name+"/"+setDeadline.name, func(t *testing.T) {
|
||||
testDeadline(t, func() deadlineTest {
|
||||
rconn, wconn := nettest.NewConnPair()
|
||||
rconn.SetReadBufferSize(1)
|
||||
return deadlineTest{
|
||||
what: "Write()",
|
||||
block: func() error {
|
||||
_, err := wconn.Write([]byte("1234"))
|
||||
wconn.Close()
|
||||
return err
|
||||
},
|
||||
unblock: func() {
|
||||
go unblock.f(rconn)
|
||||
},
|
||||
setDeadline: func(d time.Duration) {
|
||||
setDeadline.f(wconn, time.Now().Add(d))
|
||||
},
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnCanRead(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
rconn, wconn := nettest.NewConnPair()
|
||||
if got, want := rconn.CanRead(), false; got != want {
|
||||
t.Fatalf("before writing data: rconn.CanRead() = %v, want %v", got, want)
|
||||
}
|
||||
wconn.Write([]byte("a"))
|
||||
if got, want := rconn.CanRead(), true; got != want {
|
||||
t.Fatalf("after writing data: rconn.CanRead() = %v, want %v", got, want)
|
||||
}
|
||||
rconn.Read(make([]byte, 1))
|
||||
if got, want := rconn.CanRead(), false; got != want {
|
||||
t.Fatalf("after reading data: rconn.CanRead() = %v, want %v", got, want)
|
||||
}
|
||||
wconn.Close()
|
||||
if got, want := rconn.CanRead(), true; got != want {
|
||||
t.Fatalf("after closing: rconn.CanRead() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnIsClosed(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
f func() *nettest.Conn
|
||||
want bool
|
||||
}{{
|
||||
name: "unclosed",
|
||||
f: func() *nettest.Conn {
|
||||
conn, _ := nettest.NewConnPair()
|
||||
return conn
|
||||
},
|
||||
want: false,
|
||||
}, {
|
||||
name: "closed",
|
||||
f: func() *nettest.Conn {
|
||||
conn, _ := nettest.NewConnPair()
|
||||
conn.Close()
|
||||
return conn
|
||||
},
|
||||
want: true,
|
||||
}, {
|
||||
name: "read-closed",
|
||||
f: func() *nettest.Conn {
|
||||
conn, _ := nettest.NewConnPair()
|
||||
conn.CloseRead()
|
||||
return conn
|
||||
},
|
||||
want: false,
|
||||
}, {
|
||||
name: "write-closed",
|
||||
f: func() *nettest.Conn {
|
||||
conn, _ := nettest.NewConnPair()
|
||||
conn.CloseWrite()
|
||||
return conn
|
||||
},
|
||||
want: false,
|
||||
}, {
|
||||
name: "read-write-closed",
|
||||
f: func() *nettest.Conn {
|
||||
conn, _ := nettest.NewConnPair()
|
||||
conn.CloseRead()
|
||||
conn.CloseWrite()
|
||||
return conn
|
||||
},
|
||||
want: true,
|
||||
}} {
|
||||
synctestSubtest(t, test.name, func(t *testing.T) {
|
||||
conn := test.f()
|
||||
if got, want := conn.IsClosed(), test.want; got != want {
|
||||
t.Fatalf("conn.IsClosed() = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := conn.Peer().IsClosed(), false; got != want {
|
||||
t.Fatalf("conn.Peer().IsClosed() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var anyError = errors.New("any") // anyError is passed to isOpError to match any error
|
||||
|
||||
func isOpError(err, want error) bool {
|
||||
oe, ok := err.(*net.OpError)
|
||||
return ok && (oe.Err == want || want == anyError)
|
||||
}
|
||||
|
||||
func wantConnReadBytes(t *testing.T, c *nettest.Conn, want []byte) {
|
||||
t.Helper()
|
||||
got := make([]byte, len(want))
|
||||
n, err := io.ReadFull(c, got)
|
||||
if n < len(want) || err != nil {
|
||||
t.Fatalf("io.ReadFull = %v, %v; want %v, nil", n, err, len(want))
|
||||
}
|
||||
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("io.ReadFull read %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func wantConnReadErr(t *testing.T, c *nettest.Conn, want error) {
|
||||
t.Helper()
|
||||
n, err := c.Read(make([]byte, 1))
|
||||
if want == io.EOF {
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("c.Read() = %v, %v; want 0, io.EOF", n, err)
|
||||
}
|
||||
} else {
|
||||
if n != 0 || !isOpError(err, want) {
|
||||
t.Fatalf("c.Read() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wantConnReadBlocked(t *testing.T, c *nettest.Conn) {
|
||||
done := false
|
||||
go func() {
|
||||
n, err := c.Read(make([]byte, 1))
|
||||
if n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("c.Read() = %v, %v; want 0, ErrDeadlineExceeded", n, err)
|
||||
}
|
||||
done = true
|
||||
}()
|
||||
synctest.Wait()
|
||||
if done {
|
||||
t.Fatalf("Read unexpectedly returned before setting deadline")
|
||||
}
|
||||
c.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||
synctest.Wait()
|
||||
c.SetReadDeadline(time.Time{})
|
||||
if !done {
|
||||
t.Fatalf("Read unexpectedly did not return after setting deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnSetReadError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
wantErr := errors.New("error")
|
||||
rconn, wconn := nettest.NewConnPair()
|
||||
rconn.SetReadError(wantErr)
|
||||
|
||||
// Consume buffer before returning error.
|
||||
wconn.Write([]byte("one"))
|
||||
wantConnReadBytes(t, rconn, []byte("one"))
|
||||
wantConnReadErr(t, rconn, wantErr)
|
||||
|
||||
// Write more to the buffer, suppressing error until buffer drains again.
|
||||
wconn.Write([]byte("two"))
|
||||
wantConnReadBytes(t, rconn, []byte("two"))
|
||||
wantConnReadErr(t, rconn, wantErr)
|
||||
|
||||
// Error may be cleared.
|
||||
rconn.SetReadError(nil)
|
||||
wantConnReadBlocked(t, rconn)
|
||||
|
||||
// Close overrides read error.
|
||||
rconn.SetReadError(wantErr)
|
||||
wconn.Write([]byte("three"))
|
||||
wconn.Close()
|
||||
wantConnReadBytes(t, rconn, []byte("three"))
|
||||
wantConnReadErr(t, rconn, io.EOF)
|
||||
|
||||
// Setting another read error does not override Close.
|
||||
rconn.SetReadError(nil)
|
||||
wantConnReadErr(t, rconn, io.EOF)
|
||||
rconn.SetReadError(wantErr)
|
||||
wantConnReadErr(t, rconn, io.EOF)
|
||||
|
||||
// ErrClosed takes precedence over read error.
|
||||
rconn.Close()
|
||||
wantConnReadErr(t, rconn, net.ErrClosed)
|
||||
})
|
||||
}
|
||||
|
||||
func wantConnWriteBytes(t *testing.T, c *nettest.Conn, b []byte) {
|
||||
t.Helper()
|
||||
if n, err := c.Write(b); n != len(b) || err != nil {
|
||||
t.Fatalf("c.Write() = %v, %v; want %v, nil", n, err, len(b))
|
||||
}
|
||||
}
|
||||
|
||||
func wantConnWriteErr(t *testing.T, c *nettest.Conn, want error) {
|
||||
t.Helper()
|
||||
n, err := c.Write(make([]byte, 1))
|
||||
if n != 0 || !isOpError(err, want) {
|
||||
t.Fatalf("c.Write() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnSetWriteError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
wantErr := errors.New("error")
|
||||
rconn, wconn := nettest.NewConnPair()
|
||||
wconn.SetWriteError(wantErr)
|
||||
|
||||
// Error blocks writes.
|
||||
wantConnWriteErr(t, wconn, wantErr)
|
||||
wantConnReadBlocked(t, rconn)
|
||||
|
||||
// Error may be cleared.
|
||||
wconn.SetWriteError(nil)
|
||||
wantConnWriteBytes(t, wconn, []byte("one"))
|
||||
|
||||
// Restoring error does not prevent reading buffered data.
|
||||
wconn.SetWriteError(wantErr)
|
||||
wantConnWriteErr(t, wconn, wantErr)
|
||||
wantConnReadBytes(t, rconn, []byte("one"))
|
||||
|
||||
// Error does not interfere with closing the conn.
|
||||
wconn.Close()
|
||||
wantConnReadErr(t, rconn, io.EOF)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnSetCloseError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
wantErr := errors.New("error")
|
||||
rconn, wconn := nettest.NewConnPair()
|
||||
|
||||
wconn.SetCloseError(wantErr)
|
||||
if _, err := wconn.Write([]byte("one")); err != nil {
|
||||
t.Fatalf("wconn.Write = %v, want success", err)
|
||||
}
|
||||
if err := wconn.Close(); !isOpError(err, wantErr) {
|
||||
t.Fatalf("wconn.Close = %v, want OpError{Err: %v}", err, wantErr)
|
||||
}
|
||||
if err := wconn.Close(); !isOpError(err, net.ErrClosed) {
|
||||
t.Fatalf("wconn.Close = %v, want OpError{Err: net.ErrClosed}", err)
|
||||
}
|
||||
wantConnReadBytes(t, rconn, []byte("one"))
|
||||
wantConnReadErr(t, rconn, io.EOF)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnCloseReadWriteError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
conn, _ := nettest.NewConnPair()
|
||||
conn.SetCloseError(errors.New("error"))
|
||||
if err := conn.CloseRead(); err != nil {
|
||||
t.Fatalf("conn.CloseRead = %v, want nil", err)
|
||||
}
|
||||
if err := conn.CloseWrite(); err != nil {
|
||||
t.Fatalf("conn.CloseRead = %v, want nil", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
148
src/internal/nettest/listener.go
Normal file
148
src/internal/nettest/listener.go
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"internal/gate"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// Listener is an in-memory test implementation of net.Listener.
|
||||
type Listener struct {
|
||||
gate gate.Gate
|
||||
queue queue[*Conn]
|
||||
closed bool
|
||||
acceptErr error
|
||||
closeErr error
|
||||
|
||||
addr net.Addr
|
||||
nextaddr netip.AddrPort
|
||||
}
|
||||
|
||||
// NewListener returns a new Listener.
|
||||
func NewListener() *Listener {
|
||||
return &Listener{
|
||||
gate: gate.New(false),
|
||||
addr: net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1000")),
|
||||
nextaddr: netip.MustParseAddrPort("127.0.0.1:10001"),
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the listener.
|
||||
// Any blocked Accept operations will be unblocked and return errors.
|
||||
func (li *Listener) Close() error {
|
||||
li.lock()
|
||||
defer li.unlock()
|
||||
err := li.closeErr
|
||||
li.closed = true
|
||||
li.acceptErr = net.ErrClosed
|
||||
li.closeErr = net.ErrClosed
|
||||
if err != nil {
|
||||
return &net.OpError{
|
||||
Op: "close",
|
||||
Net: "tcp",
|
||||
Addr: li.addr,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
//
|
||||
// The address is always a *net.TCPAddr.
|
||||
func (li *Listener) Addr() net.Addr {
|
||||
li.lock()
|
||||
defer li.unlock()
|
||||
return li.addr
|
||||
}
|
||||
|
||||
// SetAddr sets the listener's network address.
|
||||
func (li *Listener) SetAddr(addr net.Addr) {
|
||||
li.lock()
|
||||
defer li.unlock()
|
||||
li.addr = addr
|
||||
}
|
||||
|
||||
// NewConn returns a new connection to the listener.
|
||||
//
|
||||
// Accept will return the other side of the conn.
|
||||
func (li *Listener) NewConn() *Conn {
|
||||
return li.NewConnConfig(func(*Conn) {})
|
||||
}
|
||||
|
||||
// NewConnConfig returns a new connection to the listener.
|
||||
//
|
||||
// The function f is called with the new client connection.
|
||||
// After f returns, Accept will return the other side of the connection.
|
||||
//
|
||||
// For example, to create a connection from a specific IP address:
|
||||
//
|
||||
// conn := li.NewConnConfig(func(conn *nettest.Conn) {
|
||||
// conn.SetLocalAddr(net.TCPAddrFromAddrPort(netip.MustParseAddrPort("10.0.0.1:1234")))
|
||||
// })
|
||||
func (li *Listener) NewConnConfig(f func(*Conn)) *Conn {
|
||||
li.lock()
|
||||
defer li.unlock()
|
||||
cli, srv := newConnPair(
|
||||
net.TCPAddrFromAddrPort(li.nextaddr),
|
||||
li.addr,
|
||||
)
|
||||
li.nextaddr = netip.AddrPortFrom(li.nextaddr.Addr(), li.nextaddr.Port()+1)
|
||||
f(cli)
|
||||
li.queue.push(srv)
|
||||
return cli
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
//
|
||||
// The connections returned by Accept are always [*Conn]s.
|
||||
func (li *Listener) Accept() (net.Conn, error) {
|
||||
li.gate.WaitAndLock(context.Background())
|
||||
defer li.unlock()
|
||||
if li.acceptErr != nil && li.queue.len() == 0 {
|
||||
return nil, &net.OpError{
|
||||
Op: "accept",
|
||||
Net: "tcp",
|
||||
Addr: li.addr,
|
||||
Err: li.acceptErr,
|
||||
}
|
||||
}
|
||||
return li.queue.pop(), nil
|
||||
}
|
||||
|
||||
// SetAcceptError causes any currently blocked and future Accept calls to return
|
||||
// a net.OpError wrapping err.
|
||||
// Accept will return any available connections before returning the error,
|
||||
// including connections created after the error is set.
|
||||
// A nil error restores the usual behavior.
|
||||
func (li *Listener) SetAcceptError(err error) {
|
||||
li.gate.Lock()
|
||||
defer li.unlock()
|
||||
if !li.closed {
|
||||
li.acceptErr = err
|
||||
}
|
||||
}
|
||||
|
||||
// SetCloseError sets the error returned by Close.
|
||||
// Close still closes the listener.
|
||||
// A nil error restores the usual behavior.
|
||||
func (li *Listener) SetCloseError(err error) {
|
||||
li.gate.Lock()
|
||||
defer li.unlock()
|
||||
if !li.closed {
|
||||
li.closeErr = err
|
||||
}
|
||||
}
|
||||
|
||||
func (li *Listener) lock() {
|
||||
li.gate.Lock()
|
||||
}
|
||||
|
||||
func (li *Listener) unlock() {
|
||||
li.gate.Unlock(li.acceptErr != nil || li.queue.len() > 0)
|
||||
}
|
||||
212
src/internal/nettest/listener_test.go
Normal file
212
src/internal/nettest/listener_test.go
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"internal/nettest"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
)
|
||||
|
||||
func TestListenerNewConn(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
li := nettest.NewListener()
|
||||
defer li.Close()
|
||||
|
||||
// Create several connections in parallel.
|
||||
want := []string{"a", "b", "c"}
|
||||
for i := range len(want) {
|
||||
go func() {
|
||||
conn := li.NewConn()
|
||||
defer conn.Close()
|
||||
n, err := conn.Write([]byte(want[i]))
|
||||
if n != len(want[i]) || err != nil {
|
||||
t.Errorf("conn%v.Write() = %v, %v; want %v, nil", i, n, err, len(want[i]))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Accept the connections in parallel as well.
|
||||
got := make([]string, len(want))
|
||||
for i := range len(want) {
|
||||
go func() {
|
||||
conn, err := li.Accept()
|
||||
if err != nil {
|
||||
t.Errorf("li.Accept() = %v", err)
|
||||
}
|
||||
b, err := io.ReadAll(conn)
|
||||
if err != nil {
|
||||
t.Errorf("io.ReadAll(conn%v) = %v", i, err)
|
||||
}
|
||||
got[i] = string(b)
|
||||
}()
|
||||
}
|
||||
|
||||
synctest.Wait()
|
||||
slices.Sort(got)
|
||||
slices.Sort(want)
|
||||
if !slices.Equal(got, want) {
|
||||
t.Errorf("connections read %v; want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenerInterruptAccept(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
li := nettest.NewListener()
|
||||
|
||||
var acceptErr error
|
||||
go func() {
|
||||
_, acceptErr = li.Accept()
|
||||
}()
|
||||
|
||||
synctest.Wait()
|
||||
if acceptErr != nil {
|
||||
t.Fatalf("li.Accept() = %v, want still running before close", acceptErr)
|
||||
}
|
||||
|
||||
li.Close()
|
||||
synctest.Wait()
|
||||
if !errors.Is(acceptErr, net.ErrClosed) {
|
||||
t.Fatalf("li.Accept() = %v, want ErrClosed", acceptErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenerAddresses(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
srvaddr := netip.MustParseAddrPort("10.0.0.1:80")
|
||||
cliaddr := netip.MustParseAddrPort("10.0.0.2:1234")
|
||||
|
||||
li := nettest.NewListener()
|
||||
defer li.Close()
|
||||
|
||||
li.SetAddr(net.TCPAddrFromAddrPort(srvaddr))
|
||||
if got, want := li.Addr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
|
||||
t.Errorf("li.Addr() = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
cli := li.NewConnConfig(func(conn *nettest.Conn) {
|
||||
conn.SetLocalAddr(net.TCPAddrFromAddrPort(cliaddr))
|
||||
})
|
||||
srvc, err := li.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("li.Accept() = %v", err)
|
||||
}
|
||||
srv := srvc.(*nettest.Conn)
|
||||
|
||||
if cli.Peer() != srv {
|
||||
t.Errorf("cli.Peer() != srv; should be the same")
|
||||
}
|
||||
if srv.Peer() != cli {
|
||||
t.Errorf("cli.Peer() != srv; should be the same")
|
||||
}
|
||||
|
||||
if got, want := cli.LocalAddr().(*net.TCPAddr).AddrPort(), cliaddr; got != want {
|
||||
t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := cli.RemoteAddr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
|
||||
t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := srv.LocalAddr().(*net.TCPAddr).AddrPort(), srvaddr; got != want {
|
||||
t.Errorf("srv.LocalAddr() = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := srv.RemoteAddr().(*net.TCPAddr).AddrPort(), cliaddr; got != want {
|
||||
t.Errorf("cli.LocalAddr() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func wantListenerAccept(t *testing.T, li *nettest.Listener, want *nettest.Conn) {
|
||||
t.Helper()
|
||||
got, err := li.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("li.Accept() = %v, want conn", err)
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("li.Accept() returned unexpected conn")
|
||||
}
|
||||
}
|
||||
|
||||
func wantListenerAcceptErr(t *testing.T, li *nettest.Listener, want error) {
|
||||
t.Helper()
|
||||
got, err := li.Accept()
|
||||
if got != nil || !isOpError(err, want) {
|
||||
t.Fatalf("li.Accept() = %p, %v; want nil, OpError{Err: %q}", got, err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func wantListenerAcceptBlocked(t *testing.T, li *nettest.Listener) {
|
||||
cancelErr := errors.New("cancel")
|
||||
done := false
|
||||
go func() {
|
||||
got, err := li.Accept()
|
||||
if got != nil || !errors.Is(err, cancelErr) {
|
||||
t.Errorf("li.Accept = %p, %v; want nil, cancelErr", got, err)
|
||||
}
|
||||
done = true
|
||||
}()
|
||||
synctest.Wait()
|
||||
if done {
|
||||
t.Fatalf("Accept unexpectedly returned before canceling")
|
||||
}
|
||||
li.SetAcceptError(cancelErr)
|
||||
synctest.Wait()
|
||||
li.SetAcceptError(nil)
|
||||
if !done {
|
||||
t.Fatalf("Accept unexpectedly did not return after canceling")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerSetAcceptError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
acceptErr := errors.New("accept error")
|
||||
li := nettest.NewListener()
|
||||
defer li.Close()
|
||||
li.SetAcceptError(acceptErr)
|
||||
|
||||
// Accept conns from queue before returning error.
|
||||
c1 := li.NewConn()
|
||||
wantListenerAccept(t, li, c1.Peer())
|
||||
wantListenerAcceptErr(t, li, acceptErr)
|
||||
|
||||
// Add a new conn, suppressing error until the queue is empty.
|
||||
c2 := li.NewConn()
|
||||
wantListenerAccept(t, li, c2.Peer())
|
||||
wantListenerAcceptErr(t, li, acceptErr)
|
||||
|
||||
// Error may be cleared.
|
||||
li.SetAcceptError(nil)
|
||||
wantListenerAcceptBlocked(t, li)
|
||||
|
||||
// ErrClosed takes precedence over accept error.
|
||||
li.SetAcceptError(acceptErr)
|
||||
li.Close()
|
||||
wantListenerAcceptErr(t, li, net.ErrClosed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenerSetCloseError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
li := nettest.NewListener()
|
||||
closeErr := errors.New("close error")
|
||||
li.SetCloseError(closeErr)
|
||||
|
||||
// First close uses the user-provided error.
|
||||
if err := li.Close(); !isOpError(err, closeErr) {
|
||||
t.Fatalf("li.Close() = %v; want OpError wrapping accept error", err)
|
||||
}
|
||||
|
||||
// Repeated closes return ErrClosed.
|
||||
if err := li.Close(); !isOpError(err, net.ErrClosed) {
|
||||
t.Fatalf("li.Close() = %v; want OpError wrapping net.ErrClosed", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
134
src/internal/nettest/nettest_test.go
Normal file
134
src/internal/nettest/nettest_test.go
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"internal/nettest"
|
||||
"net"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
_ net.Conn = (*nettest.Conn)(nil)
|
||||
_ net.Listener = (*nettest.Listener)(nil)
|
||||
_ net.PacketConn = (*nettest.PacketConn)(nil)
|
||||
)
|
||||
|
||||
func synctestSubtest(t *testing.T, name string, f func(t *testing.T)) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
f(t)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// A deadlineTest describes an operation which blocks until a deadline,
|
||||
// and a separate operation which unblocks it.
|
||||
type deadlineTest struct {
|
||||
what string // name of the blocking op; e.g., "Read"
|
||||
block func() error // blocking op; e.g., reading from a conn
|
||||
unblock func() // unblocking op; e.g. writing to the other side
|
||||
setDeadline func(d time.Duration) // deadline func; e.g., SetReadDeadline
|
||||
}
|
||||
|
||||
// testDeadline tests a variety of scenarios involving deadlines.
|
||||
func testDeadline(t *testing.T, setup func() deadlineTest) {
|
||||
synctestSubtest(t, "no deadline", func(t *testing.T) {
|
||||
test := setup()
|
||||
test.unblock()
|
||||
synctest.Wait()
|
||||
if err := test.block(); errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("%v: %v, want not deadline exceeded", test.what, err)
|
||||
}
|
||||
})
|
||||
synctestSubtest(t, "unblock before setdeadline", func(t *testing.T) {
|
||||
test := setup()
|
||||
test.unblock()
|
||||
synctest.Wait()
|
||||
test.setDeadline(5 * time.Second)
|
||||
if err := test.block(); errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("%v: %v, want not deadline exceeded", test.what, err)
|
||||
}
|
||||
})
|
||||
synctestSubtest(t, "unblock after blocking", func(t *testing.T) {
|
||||
test := setup()
|
||||
test.setDeadline(5 * time.Second)
|
||||
var done bool
|
||||
go func() {
|
||||
if err := test.block(); errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("%v: %v, want not deadline exceeded", test.what, err)
|
||||
}
|
||||
done = true
|
||||
}()
|
||||
synctest.Wait()
|
||||
if done {
|
||||
t.Fatalf("%v: unexpectedly returned before unblocking", test.what)
|
||||
}
|
||||
test.unblock()
|
||||
synctest.Wait()
|
||||
if !done {
|
||||
t.Fatalf("%v: did not return after unblocking", test.what)
|
||||
}
|
||||
})
|
||||
synctestSubtest(t, "deadline expires", func(t *testing.T) {
|
||||
test := setup()
|
||||
start := time.Now()
|
||||
const delay = 5 * time.Second
|
||||
test.setDeadline(delay)
|
||||
var done atomic.Bool
|
||||
go func() {
|
||||
if err := test.block(); !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("%v: %v, want os.ErrDeadlineExceeded", test.what, err)
|
||||
}
|
||||
if got, want := time.Since(start), delay; got != want {
|
||||
t.Errorf("%v: returned after %v, want %v", test.what, got, want)
|
||||
}
|
||||
done.Store(true)
|
||||
}()
|
||||
synctest.Wait()
|
||||
if done.Load() {
|
||||
t.Fatalf("%v: unexpectedly returned before unblocking", test.what)
|
||||
}
|
||||
time.Sleep(delay)
|
||||
synctest.Wait()
|
||||
if !done.Load() {
|
||||
t.Fatalf("%v: did not return after deadline", test.what)
|
||||
}
|
||||
})
|
||||
synctestSubtest(t, "deadline already expired", func(t *testing.T) {
|
||||
test := setup()
|
||||
test.setDeadline(-1 * time.Second)
|
||||
test.unblock()
|
||||
synctest.Wait()
|
||||
if err := test.block(); !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("%v: %v, want os.ErrDeadlineExceeded", test.what, err)
|
||||
}
|
||||
})
|
||||
synctestSubtest(t, "reduce deadline after blocking", func(t *testing.T) {
|
||||
test := setup()
|
||||
test.setDeadline(5 * time.Second)
|
||||
var done bool
|
||||
go func() {
|
||||
if err := test.block(); !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("%v: %v, want os.ErrDeadlineExceeded", test.what, err)
|
||||
}
|
||||
done = true
|
||||
}()
|
||||
synctest.Wait()
|
||||
if done {
|
||||
t.Fatalf("%v: unexpectedly returned before reducing deadline", test.what)
|
||||
}
|
||||
test.setDeadline(-1 * time.Second)
|
||||
synctest.Wait()
|
||||
if !done {
|
||||
t.Fatalf("%v: did not return after deadline", test.what)
|
||||
}
|
||||
})
|
||||
}
|
||||
250
src/internal/nettest/packetconn.go
Normal file
250
src/internal/nettest/packetconn.go
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"internal/gate"
|
||||
"net"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A PacketNet is a group of communicating [PacketConn]s.
|
||||
type PacketNet struct {
|
||||
mu sync.Mutex
|
||||
conns map[netAddr]*PacketConn
|
||||
}
|
||||
|
||||
type netAddr struct {
|
||||
network string
|
||||
addr string
|
||||
}
|
||||
|
||||
// NewPacketNet returns a new PacketNet.
|
||||
func NewPacketNet() *PacketNet {
|
||||
return &PacketNet{
|
||||
conns: make(map[netAddr]*PacketConn),
|
||||
}
|
||||
}
|
||||
|
||||
// NewConn returns a new [PacketConn] listening on the given address.
|
||||
// It returns an error if there is an existing listener on this address.
|
||||
func (n *PacketNet) NewConn(a net.Addr) (*PacketConn, error) {
|
||||
n.mu.Lock()
|
||||
defer n.mu.Unlock()
|
||||
addrKey := netAddr{a.Network(), a.String()}
|
||||
if _, ok := n.conns[addrKey]; ok {
|
||||
return nil, &net.OpError{
|
||||
Op: "listen",
|
||||
Net: "udp",
|
||||
Addr: a,
|
||||
Err: errors.New("address is in use"),
|
||||
}
|
||||
}
|
||||
p := &PacketConn{
|
||||
gate: gate.New(false),
|
||||
addr: a,
|
||||
net: n,
|
||||
}
|
||||
n.conns[addrKey] = p
|
||||
return p, nil
|
||||
}
|
||||
|
||||
type PacketConn struct {
|
||||
gate gate.Gate
|
||||
queue queue[*packet]
|
||||
closed bool
|
||||
readErr error
|
||||
writeErr error
|
||||
closeErr error
|
||||
readDeadline connDeadline
|
||||
|
||||
net *PacketNet
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
type packet struct {
|
||||
b []byte
|
||||
src net.Addr
|
||||
}
|
||||
|
||||
// ReadFrom reads a packet from the connection, copying the payload into b.
|
||||
func (p *PacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||
p.gate.WaitAndLock(context.Background())
|
||||
defer p.unlock()
|
||||
|
||||
switch {
|
||||
case p.closed:
|
||||
err = net.ErrClosed
|
||||
case p.readDeadline.expired:
|
||||
err = os.ErrDeadlineExceeded
|
||||
case p.queue.len() == 0 && p.readErr != nil:
|
||||
err = p.readErr
|
||||
}
|
||||
if err != nil {
|
||||
return 0, nil, &net.OpError{
|
||||
Op: "read",
|
||||
Net: "udp",
|
||||
Addr: p.addr,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
pkt := p.queue.pop()
|
||||
n = copy(b, pkt.b)
|
||||
return n, pkt.src, nil
|
||||
}
|
||||
|
||||
// WriteTo writes a packet with payload b to addr.
|
||||
// addr must be a [*net.UDPAddr].
|
||||
//
|
||||
// WriteTo appends the packet to the recipient's receive buffer.
|
||||
// If no recipient is listening on addr or if the recipient's
|
||||
// receive buffer is full, the packet is silently discarded.
|
||||
func (p *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
p.gate.Lock()
|
||||
switch {
|
||||
case p.closed:
|
||||
err = net.ErrClosed
|
||||
case p.writeErr != nil:
|
||||
err = p.writeErr
|
||||
}
|
||||
p.unlock()
|
||||
if err != nil {
|
||||
return 0, &net.OpError{
|
||||
Op: "write",
|
||||
Net: "udp",
|
||||
Source: p.addr,
|
||||
Addr: addr,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
p.net.mu.Lock()
|
||||
dst := p.net.conns[netAddr{addr.Network(), addr.String()}]
|
||||
p.net.mu.Unlock()
|
||||
|
||||
if dst == nil {
|
||||
// There is no PacketConn listening on the destination address,
|
||||
// and the packet falls silently into the void.
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
dst.lock()
|
||||
if !dst.closed {
|
||||
dst.queue.push(&packet{b: slices.Clone(b), src: p.addr})
|
||||
}
|
||||
dst.unlock()
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (p *PacketConn) Close() error {
|
||||
p.net.mu.Lock()
|
||||
delete(p.net.conns, netAddr{p.addr.Network(), p.addr.String()})
|
||||
p.net.mu.Unlock()
|
||||
|
||||
p.lock()
|
||||
defer p.unlock()
|
||||
err := p.closeErr
|
||||
p.closed = true
|
||||
p.readErr = net.ErrClosed
|
||||
p.writeErr = net.ErrClosed
|
||||
p.closeErr = net.ErrClosed
|
||||
if err != nil {
|
||||
return &net.OpError{
|
||||
Op: "close",
|
||||
Net: "udp",
|
||||
Addr: p.addr,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// LocalAddr returns the (fake) local network address.
|
||||
func (p *PacketConn) LocalAddr() net.Addr {
|
||||
p.lock()
|
||||
defer p.unlock()
|
||||
return p.addr
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline for the connection.
|
||||
// PacketConns have no write deadline.
|
||||
func (p *PacketConn) SetDeadline(t time.Time) error {
|
||||
return p.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline for the connection.
|
||||
func (p *PacketConn) SetReadDeadline(t time.Time) error {
|
||||
p.readDeadline.setDeadline(p, t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline has no effect.
|
||||
// Writes to PacketConns never block.
|
||||
func (p *PacketConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadError causes any currently blocked and future ReadFrom calls to return
|
||||
// a net.OpError wrapping err. It does not affect the other half of the connection.
|
||||
// Reads will return any buffered data before returning the error,
|
||||
// including data written after the error is set.
|
||||
// A nil error restores the usual behavior.
|
||||
func (c *PacketConn) SetReadError(err error) {
|
||||
c.lock()
|
||||
defer c.unlock()
|
||||
c.readErr = err
|
||||
}
|
||||
|
||||
// SetWriteError causes any currently blocked and future WriteTo calls to return
|
||||
// a net.OpError wrapping err. It does not affect the other half of the connection.
|
||||
// Writes will not write data while an error is set.
|
||||
// A nil error restores the usual behavior.
|
||||
func (c *PacketConn) SetWriteError(err error) {
|
||||
c.lock()
|
||||
defer c.unlock()
|
||||
c.writeErr = err
|
||||
}
|
||||
|
||||
// SetCloseError sets the error returned by Close.
|
||||
// Close still closes the connection.
|
||||
// A nil error restores the usual behavior.
|
||||
func (c *PacketConn) SetCloseError(err error) {
|
||||
c.lock()
|
||||
defer c.unlock()
|
||||
c.closeErr = err
|
||||
}
|
||||
|
||||
// CanRead reports whether [ReadFrom] can return at least one byte or an error.
|
||||
// If [ReadFrom] would block, CanRead returns false.
|
||||
func (p *PacketConn) CanRead() bool {
|
||||
p.lock()
|
||||
defer p.unlock()
|
||||
return p.canReadLocked()
|
||||
}
|
||||
|
||||
func (p *PacketConn) canReadLocked() bool {
|
||||
return p.queue.len() > 0 || p.readDeadline.expired || p.closed || p.readErr != nil
|
||||
}
|
||||
|
||||
// IsClosed reports whether the connection has been closed.
|
||||
func (p *PacketConn) IsClosed() bool {
|
||||
p.lock()
|
||||
defer p.unlock()
|
||||
return p.closed
|
||||
}
|
||||
|
||||
func (p *PacketConn) lock() {
|
||||
p.gate.Lock()
|
||||
}
|
||||
|
||||
func (p *PacketConn) unlock() {
|
||||
p.gate.Unlock(p.canReadLocked())
|
||||
}
|
||||
329
src/internal/nettest/packetconn_test.go
Normal file
329
src/internal/nettest/packetconn_test.go
Normal file
|
|
@ -0,0 +1,329 @@
|
|||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"internal/nettest"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPacketConnListenConflict(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
addr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort("10.0.0.1:1000"))
|
||||
pnet := nettest.NewPacketNet()
|
||||
conn, err := pnet.NewConn(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("with no existing listener, pnet.NewConn(%v) = %v; want success", addr, err)
|
||||
}
|
||||
_, err = pnet.NewConn(addr)
|
||||
if err == nil {
|
||||
t.Fatalf("with existing listener, pnet.NewConn(%v) = nil; want error", addr)
|
||||
}
|
||||
conn.Close()
|
||||
_, err = pnet.NewConn(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("after closing existing listener, pnet.NewConn(%v) = %v; want success", addr, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnReadWrite(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
pnet := nettest.NewPacketNet()
|
||||
c1 := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
c2 := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
|
||||
c3 := mustNewPacketConn(t, pnet, "10.0.0.3:3000")
|
||||
|
||||
wantPacketConnWriteTo(t, c1, []byte("1->3"), c3.LocalAddr())
|
||||
wantPacketConnWriteTo(t, c2, []byte("2->3"), c3.LocalAddr())
|
||||
wantPacketConnWriteTo(t, c3, []byte("3->1"), c1.LocalAddr())
|
||||
|
||||
wantPacketConnReadBytes(t, c1, []byte("3->1"), c3.LocalAddr())
|
||||
wantPacketConnReadBytes(t, c3, []byte("1->3"), c1.LocalAddr())
|
||||
wantPacketConnReadBytes(t, c3, []byte("2->3"), c2.LocalAddr())
|
||||
wantPacketConnReadBlocked(t, c1)
|
||||
wantPacketConnReadBlocked(t, c2)
|
||||
wantPacketConnReadBlocked(t, c3)
|
||||
|
||||
// Write a packet into the void (no listener on this address).
|
||||
wantPacketConnWriteTo(t, c1, []byte("1->lost"), net.UDPAddrFromAddrPort(netip.MustParseAddrPort("10.0.0.100:1000")))
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnWriteAddressErrors(t *testing.T) {
|
||||
t.Skip("TODO: figure out if these should be errors")
|
||||
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
pnet := nettest.NewPacketNet()
|
||||
c4 := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
c6 := mustNewPacketConn(t, pnet, "[::1]:1000")
|
||||
|
||||
wantPacketConnWriteErr(t, c4, c6.LocalAddr(), anyError) // IPv4 -> IPv6
|
||||
wantPacketConnWriteErr(t, c6, c4.LocalAddr(), anyError) // IPv6 -> IPv4
|
||||
|
||||
// Not a *net.UDPAddr.
|
||||
wantPacketConnWriteErr(t, c4, net.UDPAddrFromAddrPort(netip.MustParseAddrPort("10.0.0.1:1000")), anyError)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnClose(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
pnet := nettest.NewPacketNet()
|
||||
pconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
|
||||
wantPacketConnWriteTo(t, pconn, []byte("hello"), pconn.LocalAddr())
|
||||
|
||||
if err := pconn.Close(); err != nil {
|
||||
t.Errorf("pconn.Close() = %v, want success", err)
|
||||
}
|
||||
if err := pconn.Close(); !isOpError(err, net.ErrClosed) {
|
||||
t.Errorf("pconn.Close() = %v, want ErrClosed", err)
|
||||
}
|
||||
|
||||
wantPacketConnReadErr(t, pconn, net.ErrClosed)
|
||||
wantPacketConnWriteErr(t, pconn, pconn.LocalAddr(), net.ErrClosed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnReadDeadline(t *testing.T) {
|
||||
for _, setDeadline := range []struct {
|
||||
name string
|
||||
f func(*nettest.PacketConn, time.Time) error
|
||||
}{{
|
||||
name: "SetDeadline",
|
||||
f: (*nettest.PacketConn).SetDeadline,
|
||||
}, {
|
||||
name: "SetReadDeadline",
|
||||
f: (*nettest.PacketConn).SetReadDeadline,
|
||||
}} {
|
||||
t.Run(setDeadline.name, func(t *testing.T) {
|
||||
testDeadline(t, func() deadlineTest {
|
||||
pnet := nettest.NewPacketNet()
|
||||
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
|
||||
return deadlineTest{
|
||||
what: "ReadFrom()",
|
||||
block: func() error {
|
||||
_, _, err := rconn.ReadFrom(make([]byte, 1))
|
||||
return err
|
||||
},
|
||||
unblock: func() {
|
||||
wconn.WriteTo([]byte("x"), rconn.LocalAddr())
|
||||
},
|
||||
setDeadline: func(d time.Duration) {
|
||||
setDeadline.f(rconn, time.Now().Add(d))
|
||||
},
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPacketConnWriteDeadline(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
pnet := nettest.NewPacketNet()
|
||||
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
|
||||
|
||||
// This does nothing, even though the deadline has expired.
|
||||
wconn.SetWriteDeadline(time.Now().Add(-1 * time.Second))
|
||||
|
||||
wantPacketConnWriteTo(t, wconn, []byte("data"), rconn.LocalAddr())
|
||||
wantPacketConnReadBytes(t, rconn, []byte("data"), wconn.LocalAddr())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnSetReadError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
wantErr := errors.New("error")
|
||||
pnet := nettest.NewPacketNet()
|
||||
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
|
||||
rconn.SetReadError(wantErr)
|
||||
|
||||
// Consume buffer before returning error.
|
||||
wantPacketConnWriteTo(t, wconn, []byte("one"), rconn.LocalAddr())
|
||||
wantPacketConnReadBytes(t, rconn, []byte("one"), wconn.LocalAddr())
|
||||
wantPacketConnReadErr(t, rconn, wantErr)
|
||||
|
||||
// Write more, suppressing error until buffer drains again.
|
||||
wantPacketConnWriteTo(t, wconn, []byte("two"), rconn.LocalAddr())
|
||||
wantPacketConnReadBytes(t, rconn, []byte("two"), wconn.LocalAddr())
|
||||
wantPacketConnReadErr(t, rconn, wantErr)
|
||||
|
||||
// Error may be cleared.
|
||||
rconn.SetReadError(nil)
|
||||
wantPacketConnReadBlocked(t, rconn)
|
||||
|
||||
// ErrClosed takes precedence over read error.
|
||||
rconn.SetReadError(wantErr)
|
||||
rconn.Close()
|
||||
wantPacketConnReadErr(t, rconn, net.ErrClosed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnSetWriteError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
wantErr := errors.New("error")
|
||||
pnet := nettest.NewPacketNet()
|
||||
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
|
||||
wconn.SetWriteError(wantErr)
|
||||
|
||||
// Error blocks writes.
|
||||
wantPacketConnWriteErr(t, wconn, rconn.LocalAddr(), wantErr)
|
||||
wantPacketConnReadBlocked(t, rconn)
|
||||
|
||||
// Error may be cleared.
|
||||
wconn.SetWriteError(nil)
|
||||
wantPacketConnWriteTo(t, wconn, []byte("one"), rconn.LocalAddr())
|
||||
|
||||
// Restoring error does not prevent reading buffered data.
|
||||
wconn.SetWriteError(wantErr)
|
||||
wantPacketConnWriteErr(t, wconn, rconn.LocalAddr(), wantErr)
|
||||
wantPacketConnReadBytes(t, rconn, []byte("one"), wconn.LocalAddr())
|
||||
|
||||
// Error does not interfere with closing the conn.
|
||||
wconn.Close()
|
||||
wantPacketConnWriteErr(t, wconn, rconn.LocalAddr(), net.ErrClosed)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnSetCloseError(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
wantErr := errors.New("error")
|
||||
pnet := nettest.NewPacketNet()
|
||||
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
|
||||
wconn.SetCloseError(wantErr)
|
||||
|
||||
wantPacketConnWriteTo(t, wconn, []byte("one"), rconn.LocalAddr())
|
||||
if err := wconn.Close(); !isOpError(err, wantErr) {
|
||||
t.Fatalf("wconn.Close = %v, want OpError{Err: %v}", err, wantErr)
|
||||
}
|
||||
if err := wconn.Close(); !isOpError(err, net.ErrClosed) {
|
||||
t.Fatalf("wconn.Close = %v, want OpError{Err: net.ErrClosed}", err)
|
||||
}
|
||||
wantPacketConnReadBytes(t, rconn, []byte("one"), wconn.LocalAddr())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnCanRead(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
pnet := nettest.NewPacketNet()
|
||||
rconn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
wconn := mustNewPacketConn(t, pnet, "10.0.0.2:2000")
|
||||
if got, want := rconn.CanRead(), false; got != want {
|
||||
t.Fatalf("before writing data: rconn.CanRead() = %v, want %v", got, want)
|
||||
}
|
||||
wconn.WriteTo([]byte("a"), rconn.LocalAddr())
|
||||
if got, want := rconn.CanRead(), true; got != want {
|
||||
t.Fatalf("after writing data: rconn.CanRead() = %v, want %v", got, want)
|
||||
}
|
||||
rconn.ReadFrom(make([]byte, 1))
|
||||
if got, want := rconn.CanRead(), false; got != want {
|
||||
t.Fatalf("after reading data: rconn.CanRead() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPacketConnIsClosed(t *testing.T) {
|
||||
synctest.Test(t, func(t *testing.T) {
|
||||
pnet := nettest.NewPacketNet()
|
||||
conn := mustNewPacketConn(t, pnet, "10.0.0.1:1000")
|
||||
if got, want := conn.IsClosed(), false; got != want {
|
||||
t.Fatalf("before closing: conn.IsClosed() = %v, want %v", got, want)
|
||||
}
|
||||
conn.Close()
|
||||
if got, want := conn.IsClosed(), true; got != want {
|
||||
t.Fatalf("after closing: conn.IsClosed() = %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func mustNewPacketConn(t *testing.T, pnet *nettest.PacketNet, addr string) *nettest.PacketConn {
|
||||
t.Helper()
|
||||
c, err := pnet.NewConn(net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addr)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func wantPacketConnWriteTo(t *testing.T, c *nettest.PacketConn, b []byte, dst net.Addr) {
|
||||
t.Helper()
|
||||
if n, err := c.WriteTo(b, dst); n != len(b) || err != nil {
|
||||
t.Fatalf("conn.WriteTo(%q, %v) = %v, %v; want %v, nil", b, dst, n, err, len(b))
|
||||
}
|
||||
}
|
||||
|
||||
func wantPacketConnWriteErr(t *testing.T, c *nettest.PacketConn, dst net.Addr, want error) {
|
||||
t.Helper()
|
||||
n, err := c.WriteTo(make([]byte, 1), dst)
|
||||
if n != 0 || !isOpError(err, want) {
|
||||
t.Fatalf("c.WriteTo() = %v, %v; want 0, OpError{Err: %q}", n, err, want)
|
||||
}
|
||||
}
|
||||
|
||||
func wantPacketConnReadBytes(t *testing.T, c *nettest.PacketConn, want []byte, wantAddr net.Addr) {
|
||||
t.Helper()
|
||||
udpWantAddr, ok := wantAddr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
t.Fatalf("wantAddr is %T, should be *net.UDPAddr", wantAddr)
|
||||
}
|
||||
got := make([]byte, len(want)+1)
|
||||
n, addr, err := c.ReadFrom(got)
|
||||
got = got[:n]
|
||||
udpAddr, addrOK := addr.(*net.UDPAddr)
|
||||
if n != len(want) || !addrOK || udpAddr.AddrPort() != udpWantAddr.AddrPort() {
|
||||
t.Fatalf("conn.ReadFrom() = %v, %v, %v; want %v, %v, nil", n, addr, err, len(want), wantAddr)
|
||||
}
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("conn.ReadFrom() read %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func wantPacketConnReadErr(t *testing.T, c *nettest.PacketConn, want error) {
|
||||
t.Helper()
|
||||
n, addr, err := c.ReadFrom(make([]byte, 1))
|
||||
if want == io.EOF {
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("c.ReadFrom() = %v, %v, %v; want 0, nil, io.EOF", n, addr, err)
|
||||
}
|
||||
} else {
|
||||
if n != 0 || !isOpError(err, want) {
|
||||
t.Fatalf("c.ReadFrom() = %v, %v, %v; want 0, nil, OpError{Err: %q}", n, addr, err, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wantPacketConnReadBlocked(t *testing.T, c *nettest.PacketConn) {
|
||||
done := false
|
||||
go func() {
|
||||
n, addr, err := c.ReadFrom(make([]byte, 1))
|
||||
if n != 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("c.Read() = %v, %v, %v; want 0, nil, ErrDeadlineExceeded", n, addr, err)
|
||||
}
|
||||
done = true
|
||||
}()
|
||||
synctest.Wait()
|
||||
if done {
|
||||
t.Fatalf("ReadFrom unexpectedly returned before setting deadline")
|
||||
}
|
||||
c.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||
synctest.Wait()
|
||||
c.SetReadDeadline(time.Time{})
|
||||
if !done {
|
||||
t.Fatalf("ReadFrom unexpectedly did not return after setting deadline")
|
||||
}
|
||||
}
|
||||
33
src/internal/nettest/queue.go
Normal file
33
src/internal/nettest/queue.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright 2026 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package nettest
|
||||
|
||||
type queue[T any] struct {
|
||||
headPos int
|
||||
head []T
|
||||
tail []T
|
||||
}
|
||||
|
||||
func (q *queue[T]) len() int {
|
||||
return len(q.head[q.headPos:]) + len(q.tail)
|
||||
}
|
||||
|
||||
func (q *queue[T]) push(v T) {
|
||||
q.tail = append(q.tail, v)
|
||||
}
|
||||
|
||||
func (q *queue[T]) pop() T {
|
||||
var zero T
|
||||
if q.headPos >= len(q.head) {
|
||||
if len(q.tail) == 0 {
|
||||
return zero
|
||||
}
|
||||
q.head, q.headPos, q.tail = q.tail, 0, q.head[:0]
|
||||
}
|
||||
v := q.head[q.headPos]
|
||||
q.head[q.headPos] = zero
|
||||
q.headPos++
|
||||
return v
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue