caddy/network_test.go

964 lines
22 KiB
Go
Raw Permalink Normal View History

// Copyright 2015 Matthew Holt and The Caddy Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package caddy
import (
"context"
"fmt"
"net"
"strings"
"sync"
"testing"
"time"
)
func TestNetworkAddress_String_Consistency(t *testing.T) {
tests := []struct {
name string
addr NetworkAddress
}{
{
name: "basic tcp",
addr: NetworkAddress{Network: "tcp", Host: "localhost", StartPort: 8080, EndPort: 8080},
},
{
name: "tcp with port range",
addr: NetworkAddress{Network: "tcp", Host: "localhost", StartPort: 8080, EndPort: 8090},
},
{
name: "unix socket",
addr: NetworkAddress{Network: "unix", Host: "/tmp/socket"},
},
{
name: "udp",
addr: NetworkAddress{Network: "udp", Host: "0.0.0.0", StartPort: 53, EndPort: 53},
},
{
name: "ipv6",
addr: NetworkAddress{Network: "tcp", Host: "::1", StartPort: 80, EndPort: 80},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
str := test.addr.String()
// Parse the string back
parsed, err := ParseNetworkAddress(str)
if err != nil {
t.Fatalf("Failed to parse string representation: %v", err)
}
// Should be equivalent to original
if parsed.Network != test.addr.Network {
t.Errorf("Network mismatch: expected %s, got %s", test.addr.Network, parsed.Network)
}
if parsed.Host != test.addr.Host {
t.Errorf("Host mismatch: expected %s, got %s", test.addr.Host, parsed.Host)
}
if parsed.StartPort != test.addr.StartPort {
t.Errorf("StartPort mismatch: expected %d, got %d", test.addr.StartPort, parsed.StartPort)
}
if parsed.EndPort != test.addr.EndPort {
t.Errorf("EndPort mismatch: expected %d, got %d", test.addr.EndPort, parsed.EndPort)
}
})
}
}
func TestNetworkAddress_PortRangeSize_EdgeCases(t *testing.T) {
tests := []struct {
name string
addr NetworkAddress
expected uint
}{
{
name: "single port",
addr: NetworkAddress{StartPort: 80, EndPort: 80},
expected: 1,
},
{
name: "invalid range (end < start)",
addr: NetworkAddress{StartPort: 8080, EndPort: 8070},
expected: 0,
},
{
name: "zero ports",
addr: NetworkAddress{StartPort: 0, EndPort: 0},
expected: 1,
},
{
name: "maximum range",
addr: NetworkAddress{StartPort: 1, EndPort: 65535},
expected: 65535,
},
{
name: "large range",
addr: NetworkAddress{StartPort: 8000, EndPort: 9000},
expected: 1001,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
size := test.addr.PortRangeSize()
if size != test.expected {
t.Errorf("Expected %d, got %d", test.expected, size)
}
})
}
}
func TestNetworkAddress_At_Validation(t *testing.T) {
addr := NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 8080,
EndPort: 8090,
}
// Test valid offsets
for offset := uint(0); offset <= 10; offset++ {
result := addr.At(offset)
expectedPort := 8080 + offset
if result.StartPort != expectedPort || result.EndPort != expectedPort {
t.Errorf("Offset %d: expected port %d, got %d-%d",
offset, expectedPort, result.StartPort, result.EndPort)
}
if result.Network != addr.Network || result.Host != addr.Host {
t.Errorf("Offset %d: network/host should be preserved", offset)
}
}
}
func TestNetworkAddress_Expand_LargeRange(t *testing.T) {
addr := NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 8000,
EndPort: 8010,
}
expanded := addr.Expand()
expectedSize := 11 // 8000 to 8010 inclusive
if len(expanded) != expectedSize {
t.Errorf("Expected %d addresses, got %d", expectedSize, len(expanded))
}
// Verify each address
for i, expandedAddr := range expanded {
expectedPort := uint(8000 + i)
if expandedAddr.StartPort != expectedPort || expandedAddr.EndPort != expectedPort {
t.Errorf("Address %d: expected port %d, got %d-%d",
i, expectedPort, expandedAddr.StartPort, expandedAddr.EndPort)
}
}
}
func TestNetworkAddress_IsLoopback_EdgeCases(t *testing.T) {
tests := []struct {
name string
addr NetworkAddress
expected bool
}{
{
name: "unix socket",
addr: NetworkAddress{Network: "unix", Host: "/tmp/socket"},
expected: true, // Unix sockets are always considered loopback
},
{
name: "fd network",
addr: NetworkAddress{Network: "fd", Host: "3"},
expected: true, // fd networks are always considered loopback
},
{
name: "localhost",
addr: NetworkAddress{Network: "tcp", Host: "localhost"},
expected: true,
},
{
name: "127.0.0.1",
addr: NetworkAddress{Network: "tcp", Host: "127.0.0.1"},
expected: true,
},
{
name: "::1",
addr: NetworkAddress{Network: "tcp", Host: "::1"},
expected: true,
},
{
name: "127.0.0.2",
addr: NetworkAddress{Network: "tcp", Host: "127.0.0.2"},
expected: true, // Part of 127.0.0.0/8 loopback range
},
{
name: "192.168.1.1",
addr: NetworkAddress{Network: "tcp", Host: "192.168.1.1"},
expected: false, // Private but not loopback
},
{
name: "invalid ip",
addr: NetworkAddress{Network: "tcp", Host: "invalid-ip"},
expected: false,
},
{
name: "empty host",
addr: NetworkAddress{Network: "tcp", Host: ""},
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := test.addr.isLoopback()
if result != test.expected {
t.Errorf("Expected %v, got %v", test.expected, result)
}
})
}
}
func TestNetworkAddress_IsWildcard_EdgeCases(t *testing.T) {
tests := []struct {
name string
addr NetworkAddress
expected bool
}{
{
name: "empty host",
addr: NetworkAddress{Network: "tcp", Host: ""},
expected: true,
},
{
name: "ipv4 any",
addr: NetworkAddress{Network: "tcp", Host: "0.0.0.0"},
expected: true,
},
{
name: "ipv6 any",
addr: NetworkAddress{Network: "tcp", Host: "::"},
expected: true,
},
{
name: "localhost",
addr: NetworkAddress{Network: "tcp", Host: "localhost"},
expected: false,
},
{
name: "specific ip",
addr: NetworkAddress{Network: "tcp", Host: "192.168.1.1"},
expected: false,
},
{
name: "invalid ip",
addr: NetworkAddress{Network: "tcp", Host: "invalid"},
expected: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := test.addr.isWildcardInterface()
if result != test.expected {
t.Errorf("Expected %v, got %v", test.expected, result)
}
})
}
}
func TestSplitNetworkAddress_IPv6_EdgeCases(t *testing.T) {
tests := []struct {
name string
input string
expectNetwork string
expectHost string
expectPort string
expectErr bool
}{
{
name: "ipv6 with port",
input: "[::1]:8080",
expectHost: "::1",
expectPort: "8080",
},
{
name: "ipv6 without port",
input: "[::1]",
expectHost: "::1",
},
{
name: "ipv6 without brackets or port",
input: "::1",
expectHost: "::1",
},
{
name: "ipv6 loopback",
input: "[::1]:443",
expectHost: "::1",
expectPort: "443",
},
{
name: "ipv6 any address",
input: "[::]:80",
expectHost: "::",
expectPort: "80",
},
{
name: "ipv6 with network prefix",
input: "tcp6/[::1]:8080",
expectNetwork: "tcp6",
expectHost: "::1",
expectPort: "8080",
},
{
name: "malformed ipv6",
input: "[::1:8080", // Missing closing bracket
expectHost: "::1:8080",
},
{
name: "ipv6 with zone",
input: "[fe80::1%eth0]:8080",
expectHost: "fe80::1%eth0",
expectPort: "8080",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
network, host, port, err := SplitNetworkAddress(test.input)
if test.expectErr && err == nil {
t.Error("Expected error but got none")
}
if !test.expectErr && err != nil {
t.Errorf("Unexpected error: %v", err)
}
if network != test.expectNetwork {
t.Errorf("Network: expected '%s', got '%s'", test.expectNetwork, network)
}
if host != test.expectHost {
t.Errorf("Host: expected '%s', got '%s'", test.expectHost, host)
}
if port != test.expectPort {
t.Errorf("Port: expected '%s', got '%s'", test.expectPort, port)
}
})
}
}
func TestParseNetworkAddress_PortRange_Validation(t *testing.T) {
tests := []struct {
name string
input string
expectErr bool
errMsg string
}{
{
name: "valid range",
input: "localhost:8080-8090",
expectErr: false,
},
{
name: "inverted range",
input: "localhost:8090-8080",
expectErr: true,
errMsg: "end port must not be less than start port",
},
{
name: "too large range",
input: "localhost:0-65535",
expectErr: true,
errMsg: "port range exceeds 65535 ports",
},
{
name: "invalid start port",
input: "localhost:abc-8080",
expectErr: true,
},
{
name: "invalid end port",
input: "localhost:8080-xyz",
expectErr: true,
},
{
name: "port too large",
input: "localhost:99999",
expectErr: true,
},
{
name: "negative port",
input: "localhost:-80",
expectErr: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
_, err := ParseNetworkAddress(test.input)
if test.expectErr && err == nil {
t.Error("Expected error but got none")
}
if !test.expectErr && err != nil {
t.Errorf("Unexpected error: %v", err)
}
if test.expectErr && test.errMsg != "" && err != nil {
if !containsString(err.Error(), test.errMsg) {
t.Errorf("Expected error containing '%s', got '%s'", test.errMsg, err.Error())
}
}
})
}
}
func TestNetworkAddress_Listen_ContextCancellation(t *testing.T) {
addr := NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 0, // Let OS assign port
EndPort: 0,
}
// Create context that will be cancelled
ctx, cancel := context.WithCancel(context.Background())
// Start listening in a goroutine
listenDone := make(chan error, 1)
go func() {
_, err := addr.Listen(ctx, 0, net.ListenConfig{})
listenDone <- err
}()
// Cancel context immediately
cancel()
// Should get context cancellation error quickly
select {
case err := <-listenDone:
if err == nil {
t.Error("Expected error due to context cancellation")
}
// Accept any error related to context cancellation
// (could be context.Canceled or DNS lookup error due to cancellation)
case <-time.After(time.Second):
t.Error("Listen operation did not respect context cancellation")
}
}
func TestNetworkAddress_ListenAll_PartialFailure(t *testing.T) {
// Create an address range where some ports might fail to bind
addr := NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 0, // OS-assigned port
EndPort: 2, // Try to bind 3 ports starting from OS-assigned
}
// This test might be flaky depending on available ports,
// but tests the error handling logic
ctx := context.Background()
listeners, err := addr.ListenAll(ctx, net.ListenConfig{})
// Either all succeed or all fail (due to cleanup on partial failure)
if err != nil {
// If there's an error, no listeners should be returned
if len(listeners) != 0 {
t.Errorf("Expected no listeners on error, got %d", len(listeners))
}
} else {
// If successful, should have listeners for all ports in range
expectedCount := int(addr.PortRangeSize())
if len(listeners) != expectedCount {
t.Errorf("Expected %d listeners, got %d", expectedCount, len(listeners))
}
// Clean up listeners
for _, ln := range listeners {
if closer, ok := ln.(interface{ Close() error }); ok {
closer.Close()
}
}
}
}
func TestJoinNetworkAddress_SpecialCases(t *testing.T) {
tests := []struct {
name string
network string
host string
port string
expected string
}{
{
name: "empty everything",
network: "",
host: "",
port: "",
expected: "",
},
{
name: "network only",
network: "tcp",
host: "",
port: "",
expected: "tcp/",
},
{
name: "host only",
network: "",
host: "localhost",
port: "",
expected: "localhost",
},
{
name: "port only",
network: "",
host: "",
port: "8080",
expected: ":8080",
},
{
name: "unix socket with port (port ignored)",
network: "unix",
host: "/tmp/socket",
port: "8080",
expected: "unix//tmp/socket",
},
{
name: "fd network with port (port ignored)",
network: "fd",
host: "3",
port: "8080",
expected: "fd/3",
},
{
name: "ipv6 host with port",
network: "tcp",
host: "::1",
port: "8080",
expected: "tcp/[::1]:8080",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := JoinNetworkAddress(test.network, test.host, test.port)
if result != test.expected {
t.Errorf("Expected '%s', got '%s'", test.expected, result)
}
})
}
}
func TestIsUnixNetwork_IsFdNetwork(t *testing.T) {
tests := []struct {
network string
isUnix bool
isFd bool
}{
{"unix", true, false},
{"unixgram", true, false},
{"unixpacket", true, false},
{"fd", false, true},
{"fdgram", false, true},
{"tcp", false, false},
{"udp", false, false},
{"", false, false},
{"unix-like", true, false},
{"fd-like", false, true},
}
for _, test := range tests {
t.Run(test.network, func(t *testing.T) {
if IsUnixNetwork(test.network) != test.isUnix {
t.Errorf("IsUnixNetwork('%s'): expected %v, got %v",
test.network, test.isUnix, IsUnixNetwork(test.network))
}
if IsFdNetwork(test.network) != test.isFd {
t.Errorf("IsFdNetwork('%s'): expected %v, got %v",
test.network, test.isFd, IsFdNetwork(test.network))
}
// Test NetworkAddress methods too
addr := NetworkAddress{Network: test.network}
if addr.IsUnixNetwork() != test.isUnix {
t.Errorf("NetworkAddress.IsUnixNetwork(): expected %v, got %v",
test.isUnix, addr.IsUnixNetwork())
}
if addr.IsFdNetwork() != test.isFd {
t.Errorf("NetworkAddress.IsFdNetwork(): expected %v, got %v",
test.isFd, addr.IsFdNetwork())
}
})
}
}
func TestRegisterNetwork_Validation(t *testing.T) {
// Save original state
originalNetworkTypes := make(map[string]ListenerFunc)
for k, v := range networkTypes {
originalNetworkTypes[k] = v
}
defer func() {
// Restore original state
networkTypes = originalNetworkTypes
}()
mockListener := func(ctx context.Context, network, host, portRange string, portOffset uint, cfg net.ListenConfig) (any, error) {
return nil, nil
}
// Test reserved network types that should panic
reservedTypes := []string{
"tcp", "tcp4", "tcp6",
"udp", "udp4", "udp6",
"unix", "unixpacket", "unixgram",
"ip:1", "ip4:1", "ip6:1",
"fd", "fdgram",
}
for _, networkType := range reservedTypes {
t.Run("reserved_"+networkType, func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("Expected panic for reserved network type: %s", networkType)
}
}()
RegisterNetwork(networkType, mockListener)
})
}
// Test valid registration
t.Run("valid_registration", func(t *testing.T) {
customNetwork := "custom-network"
RegisterNetwork(customNetwork, mockListener)
if _, exists := networkTypes[customNetwork]; !exists {
t.Error("Custom network should be registered")
}
})
// Test duplicate registration should panic
t.Run("duplicate_registration", func(t *testing.T) {
customNetwork := "another-custom"
RegisterNetwork(customNetwork, mockListener)
defer func() {
if r := recover(); r == nil {
t.Error("Expected panic for duplicate registration")
}
}()
RegisterNetwork(customNetwork, mockListener)
})
}
func TestListenerUsage_EdgeCases(t *testing.T) {
// Test ListenerUsage function with various inputs
tests := []struct {
name string
network string
addr string
expected int
}{
{
name: "non-existent listener",
network: "tcp",
addr: "localhost:9999",
expected: 0,
},
{
name: "empty network and address",
network: "",
addr: "",
expected: 0,
},
{
name: "unix socket",
network: "unix",
addr: "/tmp/non-existent.sock",
expected: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
usage := ListenerUsage(test.network, test.addr)
if usage != test.expected {
t.Errorf("Expected usage %d, got %d", test.expected, usage)
}
})
}
}
func TestNetworkAddress_Port_Formatting(t *testing.T) {
tests := []struct {
name string
addr NetworkAddress
expected string
}{
{
name: "single port",
addr: NetworkAddress{StartPort: 80, EndPort: 80},
expected: "80",
},
{
name: "port range",
addr: NetworkAddress{StartPort: 8080, EndPort: 8090},
expected: "8080-8090",
},
{
name: "zero ports",
addr: NetworkAddress{StartPort: 0, EndPort: 0},
expected: "0",
},
{
name: "large ports",
addr: NetworkAddress{StartPort: 65534, EndPort: 65535},
expected: "65534-65535",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := test.addr.port()
if result != test.expected {
t.Errorf("Expected '%s', got '%s'", test.expected, result)
}
})
}
}
func TestNetworkAddress_JoinHostPort_SpecialNetworks(t *testing.T) {
tests := []struct {
name string
addr NetworkAddress
offset uint
expected string
}{
{
name: "unix socket ignores offset",
addr: NetworkAddress{
Network: "unix",
Host: "/tmp/socket",
},
offset: 100,
expected: "/tmp/socket",
},
{
name: "fd network ignores offset",
addr: NetworkAddress{
Network: "fd",
Host: "3",
},
offset: 50,
expected: "3",
},
{
name: "tcp with offset",
addr: NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 8000,
},
offset: 10,
expected: "localhost:8010",
},
{
name: "ipv6 with offset",
addr: NetworkAddress{
Network: "tcp",
Host: "::1",
StartPort: 8000,
},
offset: 5,
expected: "[::1]:8005",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := test.addr.JoinHostPort(test.offset)
if result != test.expected {
t.Errorf("Expected '%s', got '%s'", test.expected, result)
}
})
}
}
// Helper function for string containment check
func containsString(haystack, needle string) bool {
return len(haystack) >= len(needle) &&
(needle == "" || haystack == needle ||
strings.Contains(haystack, needle))
}
func TestListenerKey_Generation(t *testing.T) {
tests := []struct {
network string
addr string
expected string
}{
{
network: "tcp",
addr: "localhost:8080",
expected: "tcp/localhost:8080",
},
{
network: "unix",
addr: "/tmp/socket",
expected: "unix//tmp/socket",
},
{
network: "",
addr: "localhost:8080",
expected: "/localhost:8080",
},
{
network: "tcp",
addr: "",
expected: "tcp/",
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%s_%s", test.network, test.addr), func(t *testing.T) {
result := listenerKey(test.network, test.addr)
if result != test.expected {
t.Errorf("Expected '%s', got '%s'", test.expected, result)
}
})
}
}
func TestNetworkAddress_ConcurrentAccess(t *testing.T) {
// Test that NetworkAddress methods are safe for concurrent read access
addr := NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 8080,
EndPort: 8090,
}
const numGoroutines = 50
var wg sync.WaitGroup
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
// Call various methods concurrently
_ = addr.String()
_ = addr.PortRangeSize()
_ = addr.IsUnixNetwork()
_ = addr.IsFdNetwork()
_ = addr.isLoopback()
_ = addr.isWildcardInterface()
_ = addr.port()
_ = addr.JoinHostPort(uint(id % 10))
_ = addr.At(uint(id % 11))
// Expand creates new slice, should be safe
expanded := addr.Expand()
if len(expanded) == 0 {
t.Errorf("Goroutine %d: Expected non-empty expansion", id)
}
}(i)
}
wg.Wait()
}
func TestNetworkAddress_IPv6_Zone_Handling(t *testing.T) {
// Test IPv6 addresses with zone identifiers
input := "tcp/[fe80::1%eth0]:8080"
addr, err := ParseNetworkAddress(input)
if err != nil {
t.Fatalf("Failed to parse IPv6 with zone: %v", err)
}
if addr.Network != "tcp" {
t.Errorf("Expected network 'tcp', got '%s'", addr.Network)
}
if addr.Host != "fe80::1%eth0" {
t.Errorf("Expected host 'fe80::1%%eth0', got '%s'", addr.Host)
}
if addr.StartPort != 8080 {
t.Errorf("Expected port 8080, got %d", addr.StartPort)
}
// Test string representation round-trip
str := addr.String()
parsed, err := ParseNetworkAddress(str)
if err != nil {
t.Fatalf("Failed to parse string representation: %v", err)
}
if parsed.Host != addr.Host {
t.Errorf("Round-trip failed: expected host '%s', got '%s'", addr.Host, parsed.Host)
}
}
func BenchmarkParseNetworkAddress(b *testing.B) {
inputs := []string{
"localhost:8080",
"tcp/localhost:8080-8090",
"unix//tmp/socket",
"[::1]:443",
"udp/:53",
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
input := inputs[i%len(inputs)]
ParseNetworkAddress(input)
}
}
func BenchmarkNetworkAddress_String(b *testing.B) {
addr := NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 8080,
EndPort: 8090,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
addr.String()
}
}
func BenchmarkNetworkAddress_Expand(b *testing.B) {
addr := NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 8000,
EndPort: 8100, // 101 addresses
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
addr.Expand()
}
}