diff --git a/listeners.go b/listeners.go index 8a862bacf..b673c86e1 100644 --- a/listeners.go +++ b/listeners.go @@ -38,6 +38,10 @@ import ( "github.com/caddyserver/caddy/v2/internal" ) +// listenFdsStart is the first file descriptor number for systemd socket activation. +// File descriptors 0, 1, 2 are reserved for stdin, stdout, stderr. +const listenFdsStart = 3 + // NetworkAddress represents one or more network addresses. // It contains the individual components for a parsed network // address of the form accepted by ParseNetworkAddress(). @@ -305,6 +309,64 @@ func IsFdNetwork(netw string) bool { return strings.HasPrefix(netw, "fd") } +// getFdByName returns the file descriptor number for the given +// socket name from systemd's LISTEN_FDNAMES environment variable. +// Socket names are provided by systemd via socket activation. +// +// The name can optionally include an index to handle multiple sockets +// with the same name: "web:0" for first, "web:1" for second, etc. +// If no index is specified, defaults to index 0 (first occurrence). +func getFdByName(nameWithIndex string) (int, error) { + if nameWithIndex == "" { + return 0, fmt.Errorf("socket name cannot be empty") + } + + fdNamesStr := os.Getenv("LISTEN_FDNAMES") + if fdNamesStr == "" { + return 0, fmt.Errorf("LISTEN_FDNAMES environment variable not set") + } + + // Parse name and optional index + parts := strings.Split(nameWithIndex, ":") + if len(parts) > 2 { + return 0, fmt.Errorf("invalid socket name format '%s': too many colons", nameWithIndex) + } + + name := parts[0] + targetIndex := 0 + + if len(parts) > 1 { + var err error + targetIndex, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, fmt.Errorf("invalid socket index '%s': %v", parts[1], err) + } + if targetIndex < 0 { + return 0, fmt.Errorf("socket index cannot be negative: %d", targetIndex) + } + } + + // Parse the socket names + names := strings.Split(fdNamesStr, ":") + + // Find the Nth occurrence of the requested name + matchCount := 0 + for i, fdName := range names { + if fdName == name { + if matchCount == targetIndex { + return listenFdsStart + i, nil + } + matchCount++ + } + } + + if matchCount == 0 { + return 0, fmt.Errorf("socket name '%s' not found in LISTEN_FDNAMES", name) + } + + return 0, fmt.Errorf("socket name '%s' found %d times, but index %d requested", name, matchCount, targetIndex) +} + // ParseNetworkAddress parses addr into its individual // components. The input string is expected to be of // the form "network/host:port-range" where any part is @@ -336,9 +398,27 @@ func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort ui }, err } if IsFdNetwork(network) { + fdAddr := host + + // Handle named socket activation (fdname/name, fdgramname/name) + if strings.HasPrefix(network, "fdname") || strings.HasPrefix(network, "fdgramname") { + fdNum, err := getFdByName(host) + if err != nil { + return NetworkAddress{}, fmt.Errorf("named socket activation: %v", err) + } + fdAddr = strconv.Itoa(fdNum) + + // Normalize network to standard fd/fdgram + if strings.HasPrefix(network, "fdname") { + network = "fd" + } else { + network = "fdgram" + } + } + return NetworkAddress{ Network: network, - Host: host, + Host: fdAddr, }, nil } var start, end uint64 diff --git a/listeners_test.go b/listeners_test.go index a4cadd3aa..c2cc255f2 100644 --- a/listeners_test.go +++ b/listeners_test.go @@ -15,6 +15,7 @@ package caddy import ( + "os" "reflect" "testing" @@ -652,3 +653,286 @@ func TestSplitUnixSocketPermissionsBits(t *testing.T) { } } } + +// TestGetFdByName tests the getFdByName function for systemd socket activation. +func TestGetFdByName(t *testing.T) { + // Save original environment + originalFdNames := os.Getenv("LISTEN_FDNAMES") + + // Restore environment after test + defer func() { + if originalFdNames != "" { + os.Setenv("LISTEN_FDNAMES", originalFdNames) + } else { + os.Unsetenv("LISTEN_FDNAMES") + } + }() + + tests := []struct { + name string + fdNames string + socketName string + expectedFd int + expectError bool + }{ + { + name: "simple http socket", + fdNames: "http", + socketName: "http", + expectedFd: 3, + }, + { + name: "multiple different sockets - first", + fdNames: "http:https:dns", + socketName: "http", + expectedFd: 3, + }, + { + name: "multiple different sockets - second", + fdNames: "http:https:dns", + socketName: "https", + expectedFd: 4, + }, + { + name: "multiple different sockets - third", + fdNames: "http:https:dns", + socketName: "dns", + expectedFd: 5, + }, + { + name: "duplicate names - first occurrence (no index)", + fdNames: "web:web:api", + socketName: "web", + expectedFd: 3, + }, + { + name: "duplicate names - first occurrence (explicit index 0)", + fdNames: "web:web:api", + socketName: "web:0", + expectedFd: 3, + }, + { + name: "duplicate names - second occurrence (index 1)", + fdNames: "web:web:api", + socketName: "web:1", + expectedFd: 4, + }, + { + name: "complex duplicates - first api", + fdNames: "web:api:web:api:dns", + socketName: "api:0", + expectedFd: 4, + }, + { + name: "complex duplicates - second api", + fdNames: "web:api:web:api:dns", + socketName: "api:1", + expectedFd: 6, + }, + { + name: "complex duplicates - first web", + fdNames: "web:api:web:api:dns", + socketName: "web:0", + expectedFd: 3, + }, + { + name: "complex duplicates - second web", + fdNames: "web:api:web:api:dns", + socketName: "web:1", + expectedFd: 5, + }, + { + name: "socket not found", + fdNames: "http:https", + socketName: "missing", + expectError: true, + }, + { + name: "empty socket name", + fdNames: "http", + socketName: "", + expectError: true, + }, + { + name: "missing LISTEN_FDNAMES", + fdNames: "", + socketName: "http", + expectError: true, + }, + { + name: "index out of range", + fdNames: "web:web", + socketName: "web:2", + expectError: true, + }, + { + name: "negative index", + fdNames: "web", + socketName: "web:-1", + expectError: true, + }, + { + name: "invalid index format", + fdNames: "web", + socketName: "web:abc", + expectError: true, + }, + { + name: "too many colons", + fdNames: "web", + socketName: "web:0:extra", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Set up environment + if tc.fdNames != "" { + os.Setenv("LISTEN_FDNAMES", tc.fdNames) + } else { + os.Unsetenv("LISTEN_FDNAMES") + } + + // Test the function + fd, err := getFdByName(tc.socketName) + + if tc.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } + } else { + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + if fd != tc.expectedFd { + t.Errorf("Expected FD %d but got %d", tc.expectedFd, fd) + } + } + }) + } +} + +// TestParseNetworkAddressFdName tests parsing of fdname and fdgramname addresses. +func TestParseNetworkAddressFdName(t *testing.T) { + // Save and restore environment + originalFdNames := os.Getenv("LISTEN_FDNAMES") + defer func() { + if originalFdNames != "" { + os.Setenv("LISTEN_FDNAMES", originalFdNames) + } else { + os.Unsetenv("LISTEN_FDNAMES") + } + }() + + // Set up test environment + os.Setenv("LISTEN_FDNAMES", "http:https:dns") + + tests := []struct { + input string + expectAddr NetworkAddress + expectErr bool + }{ + { + input: "fdname/http", + expectAddr: NetworkAddress{ + Network: "fd", + Host: "3", + }, + }, + { + input: "fdname/https", + expectAddr: NetworkAddress{ + Network: "fd", + Host: "4", + }, + }, + { + input: "fdname/dns", + expectAddr: NetworkAddress{ + Network: "fd", + Host: "5", + }, + }, + { + input: "fdname/http:0", + expectAddr: NetworkAddress{ + Network: "fd", + Host: "3", + }, + }, + { + input: "fdname/https:0", + expectAddr: NetworkAddress{ + Network: "fd", + Host: "4", + }, + }, + { + input: "fdgramname/http", + expectAddr: NetworkAddress{ + Network: "fdgram", + Host: "3", + }, + }, + { + input: "fdgramname/https", + expectAddr: NetworkAddress{ + Network: "fdgram", + Host: "4", + }, + }, + { + input: "fdgramname/http:0", + expectAddr: NetworkAddress{ + Network: "fdgram", + Host: "3", + }, + }, + { + input: "fdname/nonexistent", + expectErr: true, + }, + { + input: "fdgramname/nonexistent", + expectErr: true, + }, + { + input: "fdname/http:99", + expectErr: true, + }, + { + input: "fdname/invalid:abc", + expectErr: true, + }, + // Test that old fd/N syntax still works + { + input: "fd/7", + expectAddr: NetworkAddress{ + Network: "fd", + Host: "7", + }, + }, + { + input: "fdgram/8", + expectAddr: NetworkAddress{ + Network: "fdgram", + Host: "8", + }, + }, + } + + for i, tc := range tests { + actualAddr, err := ParseNetworkAddress(tc.input) + + if tc.expectErr && err == nil { + t.Errorf("Test %d (%s): Expected error but got none", i, tc.input) + } + if !tc.expectErr && err != nil { + t.Errorf("Test %d (%s): Expected no error but got: %v", i, tc.input, err) + } + if !tc.expectErr && !reflect.DeepEqual(tc.expectAddr, actualAddr) { + t.Errorf("Test %d (%s): Expected %+v but got %+v", i, tc.input, tc.expectAddr, actualAddr) + } + } +}