mirror of
https://github.com/caddyserver/caddy.git
synced 2025-12-08 06:09:53 +00:00
Merge 201cba5b66 into 31960dc998
This commit is contained in:
commit
939b376fba
16 changed files with 6631 additions and 3 deletions
377
api_error_test.go
Normal file
377
api_error_test.go
Normal file
|
|
@ -0,0 +1,377 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAPIError_Error_WithErr(t *testing.T) {
|
||||
underlyingErr := errors.New("underlying error")
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: underlyingErr,
|
||||
Message: "API error message",
|
||||
}
|
||||
|
||||
result := apiErr.Error()
|
||||
expected := "underlying error"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_Error_WithoutErr(t *testing.T) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: nil,
|
||||
Message: "API error message",
|
||||
}
|
||||
|
||||
result := apiErr.Error()
|
||||
expected := "API error message"
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected '%s', got '%s'", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_Error_BothNil(t *testing.T) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: nil,
|
||||
Message: "",
|
||||
}
|
||||
|
||||
result := apiErr.Error()
|
||||
expected := ""
|
||||
|
||||
if result != expected {
|
||||
t.Errorf("Expected empty string, got '%s'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_JSON_Serialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
apiErr APIError
|
||||
}{
|
||||
{
|
||||
name: "with message only",
|
||||
apiErr: APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "validation failed",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with underlying error only",
|
||||
apiErr: APIError{
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
Err: errors.New("internal error"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with both message and error",
|
||||
apiErr: APIError{
|
||||
HTTPStatus: http.StatusConflict,
|
||||
Err: errors.New("underlying"),
|
||||
Message: "conflict detected",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "minimal error",
|
||||
apiErr: APIError{
|
||||
HTTPStatus: http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Marshal to JSON
|
||||
jsonData, err := json.Marshal(test.apiErr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal APIError: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var unmarshaled APIError
|
||||
err = json.Unmarshal(jsonData, &unmarshaled)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal APIError: %v", err)
|
||||
}
|
||||
|
||||
// Only Message field should survive JSON round-trip
|
||||
// HTTPStatus and Err are marked with json:"-"
|
||||
if unmarshaled.Message != test.apiErr.Message {
|
||||
t.Errorf("Message mismatch: expected '%s', got '%s'",
|
||||
test.apiErr.Message, unmarshaled.Message)
|
||||
}
|
||||
|
||||
// HTTPStatus and Err should be zero values after unmarshal
|
||||
if unmarshaled.HTTPStatus != 0 {
|
||||
t.Errorf("HTTPStatus should be 0 after unmarshal, got %d", unmarshaled.HTTPStatus)
|
||||
}
|
||||
if unmarshaled.Err != nil {
|
||||
t.Errorf("Err should be nil after unmarshal, got %v", unmarshaled.Err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_HTTPStatus_Values(t *testing.T) {
|
||||
// Test common HTTP status codes
|
||||
statusCodes := []int{
|
||||
http.StatusBadRequest,
|
||||
http.StatusUnauthorized,
|
||||
http.StatusForbidden,
|
||||
http.StatusNotFound,
|
||||
http.StatusMethodNotAllowed,
|
||||
http.StatusConflict,
|
||||
http.StatusPreconditionFailed,
|
||||
http.StatusInternalServerError,
|
||||
http.StatusNotImplemented,
|
||||
http.StatusServiceUnavailable,
|
||||
}
|
||||
|
||||
for _, status := range statusCodes {
|
||||
t.Run(fmt.Sprintf("status_%d", status), func(t *testing.T) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: status,
|
||||
Message: http.StatusText(status),
|
||||
}
|
||||
|
||||
if apiErr.HTTPStatus != status {
|
||||
t.Errorf("Expected status %d, got %d", status, apiErr.HTTPStatus)
|
||||
}
|
||||
|
||||
// Test that error message is reasonable
|
||||
if apiErr.Message == "" && status >= 400 {
|
||||
t.Errorf("Status %d should have a message", status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_ErrorInterface_Compliance(t *testing.T) {
|
||||
// Verify APIError properly implements error interface
|
||||
var err error = APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "test error",
|
||||
}
|
||||
|
||||
errorMsg := err.Error()
|
||||
if errorMsg != "test error" {
|
||||
t.Errorf("Expected 'test error', got '%s'", errorMsg)
|
||||
}
|
||||
|
||||
// Test with underlying error
|
||||
underlyingErr := errors.New("underlying")
|
||||
err2 := APIError{
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
Err: underlyingErr,
|
||||
Message: "wrapper",
|
||||
}
|
||||
|
||||
if err2.Error() != "underlying" {
|
||||
t.Errorf("Expected 'underlying', got '%s'", err2.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_JSON_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message string
|
||||
}{
|
||||
{
|
||||
name: "empty message",
|
||||
message: "",
|
||||
},
|
||||
{
|
||||
name: "unicode message",
|
||||
message: "Error: 🚨 Something went wrong! 你好",
|
||||
},
|
||||
{
|
||||
name: "json characters in message",
|
||||
message: `Error with "quotes" and {brackets}`,
|
||||
},
|
||||
{
|
||||
name: "newlines in message",
|
||||
message: "Line 1\nLine 2\r\nLine 3",
|
||||
},
|
||||
{
|
||||
name: "very long message",
|
||||
message: string(make([]byte, 10000)), // 10KB message
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: test.message,
|
||||
}
|
||||
|
||||
// Should be JSON serializable
|
||||
jsonData, err := json.Marshal(apiErr)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal APIError: %v", err)
|
||||
}
|
||||
|
||||
// Should be deserializable
|
||||
var unmarshaled APIError
|
||||
err = json.Unmarshal(jsonData, &unmarshaled)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal APIError: %v", err)
|
||||
}
|
||||
|
||||
if unmarshaled.Message != test.message {
|
||||
t.Errorf("Message corrupted during JSON round-trip")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_Chaining(t *testing.T) {
|
||||
// Test error chaining scenarios
|
||||
rootErr := errors.New("root cause")
|
||||
wrappedErr := fmt.Errorf("wrapped: %w", rootErr)
|
||||
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
Err: wrappedErr,
|
||||
Message: "API wrapper",
|
||||
}
|
||||
|
||||
// Error() should return the underlying error message
|
||||
if apiErr.Error() != wrappedErr.Error() {
|
||||
t.Errorf("Expected underlying error message, got '%s'", apiErr.Error())
|
||||
}
|
||||
|
||||
// Should be able to unwrap
|
||||
if !errors.Is(apiErr.Err, rootErr) {
|
||||
t.Error("Should be able to unwrap to root cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIError_StatusCode_Boundaries(t *testing.T) {
|
||||
// Test edge cases for HTTP status codes
|
||||
tests := []struct {
|
||||
name string
|
||||
status int
|
||||
valid bool
|
||||
}{
|
||||
{
|
||||
name: "negative status",
|
||||
status: -1,
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "zero status",
|
||||
status: 0,
|
||||
valid: false,
|
||||
},
|
||||
{
|
||||
name: "valid 1xx",
|
||||
status: http.StatusContinue,
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid 2xx",
|
||||
status: http.StatusOK,
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid 4xx",
|
||||
status: http.StatusBadRequest,
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "valid 5xx",
|
||||
status: http.StatusInternalServerError,
|
||||
valid: true,
|
||||
},
|
||||
{
|
||||
name: "too large status",
|
||||
status: 9999,
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
err := APIError{
|
||||
HTTPStatus: test.status,
|
||||
Message: "test",
|
||||
}
|
||||
|
||||
// The struct allows any int value, but we can test
|
||||
// if it's a valid HTTP status
|
||||
statusText := http.StatusText(test.status)
|
||||
isValidStatus := statusText != ""
|
||||
|
||||
if isValidStatus != test.valid {
|
||||
t.Errorf("Status %d validity: expected %v, got %v",
|
||||
test.status, test.valid, isValidStatus)
|
||||
}
|
||||
|
||||
// Verify the struct holds the status
|
||||
if err.HTTPStatus != test.status {
|
||||
t.Errorf("Status not preserved: expected %d, got %d", test.status, err.HTTPStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAPIError_Error(b *testing.B) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: errors.New("benchmark error"),
|
||||
Message: "benchmark message",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
apiErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAPIError_JSON_Marshal(b *testing.B) {
|
||||
apiErr := APIError{
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Err: errors.New("benchmark error"),
|
||||
Message: "benchmark message",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
json.Marshal(apiErr)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAPIError_JSON_Unmarshal(b *testing.B) {
|
||||
jsonData := []byte(`{"error": "benchmark message"}`)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
var result APIError
|
||||
_ = json.Unmarshal(jsonData, &result)
|
||||
}
|
||||
}
|
||||
234
cmd/packagesfuncs_test.go
Normal file
234
cmd/packagesfuncs_test.go
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddycmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitModule(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedModule string
|
||||
expectedVersion string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "simple module without version",
|
||||
input: "github.com/caddyserver/caddy",
|
||||
expectedModule: "github.com/caddyserver/caddy",
|
||||
expectedVersion: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with version",
|
||||
input: "github.com/caddyserver/caddy@v2.0.0",
|
||||
expectedModule: "github.com/caddyserver/caddy",
|
||||
expectedVersion: "v2.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with semantic version",
|
||||
input: "github.com/user/module@v1.2.3",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "v1.2.3",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with prerelease version",
|
||||
input: "github.com/user/module@v1.0.0-beta.1",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "v1.0.0-beta.1",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with commit hash",
|
||||
input: "github.com/user/module@abc123def",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "abc123def",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with @ in path and version",
|
||||
input: "github.com/@user/module@v1.0.0",
|
||||
expectedModule: "github.com/@user/module",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with multiple @ in path",
|
||||
input: "github.com/@org/@user/module@v2.3.4",
|
||||
expectedModule: "github.com/@org/@user/module",
|
||||
expectedVersion: "v2.3.4",
|
||||
expectError: false,
|
||||
},
|
||||
// TODO: decide on the behavior for this case; it fails currently
|
||||
// {
|
||||
// name: "module with @ in path but no version",
|
||||
// input: "github.com/@user/module",
|
||||
// expectedModule: "github.com/@user/module",
|
||||
// expectedVersion: "",
|
||||
// expectError: false,
|
||||
// },
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expectedModule: "",
|
||||
expectedVersion: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "only @ symbol",
|
||||
input: "@",
|
||||
expectedModule: "",
|
||||
expectedVersion: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "@ at start",
|
||||
input: "@v1.0.0",
|
||||
expectedModule: "",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "@ at end",
|
||||
input: "github.com/user/module@",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "multiple consecutive @",
|
||||
input: "github.com/user/module@@v1.0.0",
|
||||
expectedModule: "github.com/user/module@",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "version with latest tag",
|
||||
input: "github.com/user/module@latest",
|
||||
expectedModule: "github.com/user/module",
|
||||
expectedVersion: "latest",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "long module path",
|
||||
input: "github.com/organization/team/project/subproject/module@v3.14.159",
|
||||
expectedModule: "github.com/organization/team/project/subproject/module",
|
||||
expectedVersion: "v3.14.159",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with dots in name",
|
||||
input: "github.com/user/my.module.name@v1.0",
|
||||
expectedModule: "github.com/user/my.module.name",
|
||||
expectedVersion: "v1.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "module with hyphens",
|
||||
input: "github.com/user/my-module-name@v1.0.0",
|
||||
expectedModule: "github.com/user/my-module-name",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "gitlab module",
|
||||
input: "gitlab.com/user/module@v2.0.0",
|
||||
expectedModule: "gitlab.com/user/module",
|
||||
expectedVersion: "v2.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "bitbucket module",
|
||||
input: "bitbucket.org/user/module@v1.5.0",
|
||||
expectedModule: "bitbucket.org/user/module",
|
||||
expectedVersion: "v1.5.0",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "custom domain",
|
||||
input: "example.com/custom/module@v1.0.0",
|
||||
expectedModule: "example.com/custom/module",
|
||||
expectedVersion: "v1.0.0",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
module, version, err := splitModule(tt.input)
|
||||
|
||||
// Check error expectation
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check module
|
||||
if module != tt.expectedModule {
|
||||
t.Errorf("module: got %q, want %q", module, tt.expectedModule)
|
||||
}
|
||||
|
||||
// Check version
|
||||
if version != tt.expectedVersion {
|
||||
t.Errorf("version: got %q, want %q", version, tt.expectedVersion)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitModule_ErrorCases(t *testing.T) {
|
||||
errorCases := []string{
|
||||
"",
|
||||
"@",
|
||||
"@version",
|
||||
"@v1.0.0",
|
||||
}
|
||||
|
||||
for _, tc := range errorCases {
|
||||
t.Run("error_"+tc, func(t *testing.T) {
|
||||
_, _, err := splitModule(tc)
|
||||
if err == nil {
|
||||
t.Errorf("splitModule(%q) should return error", tc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSplitModule benchmarks the splitModule function
|
||||
func BenchmarkSplitModule(b *testing.B) {
|
||||
testCases := []string{
|
||||
"github.com/user/module",
|
||||
"github.com/user/module@v1.0.0",
|
||||
"github.com/@org/@user/module@v2.3.4",
|
||||
"github.com/organization/team/project/subproject/module@v3.14.159",
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
b.Run(tc, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
splitModule(tc)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
719
config_test.go
Normal file
719
config_test.go
Normal file
|
|
@ -0,0 +1,719 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestConfig_Start_Stop_Basic(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Admin: &AdminConfig{Disabled: true}, // Disable admin to avoid port conflicts
|
||||
}
|
||||
|
||||
ctx, err := run(cfg, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to run config: %v", err)
|
||||
}
|
||||
|
||||
// Verify context is valid
|
||||
if ctx.cfg == nil {
|
||||
t.Error("Expected non-nil config in context")
|
||||
}
|
||||
|
||||
// Stop the config
|
||||
unsyncedStop(ctx)
|
||||
|
||||
// Verify cleanup was called
|
||||
if ctx.cfg.cancelFunc == nil {
|
||||
t.Error("Expected cancel function to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_InvalidConfig(t *testing.T) {
|
||||
// Create a config with an invalid app module
|
||||
cfg := &Config{
|
||||
AppsRaw: ModuleMap{
|
||||
"non-existent-app": json.RawMessage(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
err := Validate(cfg)
|
||||
if err == nil {
|
||||
t.Error("Expected validation error for invalid app module")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_ValidConfig(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Admin: &AdminConfig{Disabled: true},
|
||||
}
|
||||
|
||||
err := Validate(cfg)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected validation error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeConfig_ConcurrentAccess(t *testing.T) {
|
||||
// Save original config state
|
||||
originalRawCfg := rawCfg[rawConfigKey]
|
||||
originalRawCfgJSON := rawCfgJSON
|
||||
defer func() {
|
||||
rawCfg[rawConfigKey] = originalRawCfg
|
||||
rawCfgJSON = originalRawCfgJSON
|
||||
}()
|
||||
|
||||
// Initialize with a basic config
|
||||
initialCfg := map[string]any{
|
||||
"test": "value",
|
||||
}
|
||||
rawCfg[rawConfigKey] = initialCfg
|
||||
|
||||
const numGoroutines = 10 // Reduced for more controlled testing
|
||||
var wg sync.WaitGroup
|
||||
errors := make([]error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Only test read operations to avoid complex state changes
|
||||
// that could cause nil pointer issues in concurrent scenarios
|
||||
var buf bytes.Buffer
|
||||
errors[index] = readConfig("/"+rawConfigKey+"/test", &buf)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check that read operations succeeded
|
||||
for i, err := range errors {
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: Unexpected read error: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeConfig_MethodValidation(t *testing.T) {
|
||||
// Save original config state
|
||||
originalRawCfg := rawCfg[rawConfigKey]
|
||||
defer func() {
|
||||
rawCfg[rawConfigKey] = originalRawCfg
|
||||
}()
|
||||
|
||||
// Set up a simple valid config for testing
|
||||
rawCfg[rawConfigKey] = map[string]any{}
|
||||
|
||||
tests := []struct {
|
||||
method string
|
||||
expectErr bool
|
||||
}{
|
||||
{http.MethodPost, false},
|
||||
{http.MethodPut, true}, // because key 'admin' already exists
|
||||
{http.MethodPatch, false},
|
||||
{http.MethodDelete, false},
|
||||
{http.MethodGet, true},
|
||||
{http.MethodHead, true},
|
||||
{http.MethodOptions, true},
|
||||
{http.MethodConnect, true},
|
||||
{http.MethodTrace, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.method, func(t *testing.T) {
|
||||
// Use a simple admin config path that won't cause complex validation
|
||||
err := changeConfig(test.method, "/"+rawConfigKey+"/admin", []byte(`{"disabled": true}`), "", false)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error for invalid method")
|
||||
}
|
||||
if !test.expectErr && err != nil && (err != errSameConfig) {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeConfig_IfMatchHeader_Validation(t *testing.T) {
|
||||
// Set up initial config
|
||||
initialCfg := map[string]any{"test": "value"}
|
||||
rawCfg[rawConfigKey] = initialCfg
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ifMatch string
|
||||
expectErr bool
|
||||
expectStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "malformed - no quotes",
|
||||
ifMatch: "path hash",
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "malformed - single quote",
|
||||
ifMatch: `"path hash`,
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "malformed - wrong number of parts",
|
||||
ifMatch: `"path"`,
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "malformed - too many parts",
|
||||
ifMatch: `"path hash extra"`,
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "wrong hash",
|
||||
ifMatch: `"/config/test wronghash"`,
|
||||
expectErr: true,
|
||||
expectStatusCode: http.StatusPreconditionFailed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
err := changeConfig(http.MethodPost, "/"+rawConfigKey+"/test", []byte(`"newvalue"`), test.ifMatch, false)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if test.expectErr && err != nil {
|
||||
if apiErr, ok := err.(APIError); ok {
|
||||
if apiErr.HTTPStatus != test.expectStatusCode {
|
||||
t.Errorf("Expected status %d, got %d", test.expectStatusCode, apiErr.HTTPStatus)
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected APIError type")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexConfigObjects_Basic(t *testing.T) {
|
||||
config := map[string]any{
|
||||
"app1": map[string]any{
|
||||
"@id": "my-app",
|
||||
"config": "value",
|
||||
},
|
||||
"nested": map[string]any{
|
||||
"array": []any{
|
||||
map[string]any{
|
||||
"@id": "nested-item",
|
||||
"data": "test",
|
||||
},
|
||||
map[string]any{
|
||||
"@id": 123.0, // JSON numbers are float64
|
||||
"more": "data",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
index := make(map[string]string)
|
||||
err := indexConfigObjects(config, "/config", index)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
expected := map[string]string{
|
||||
"my-app": "/config/app1",
|
||||
"nested-item": "/config/nested/array/0",
|
||||
"123": "/config/nested/array/1",
|
||||
}
|
||||
|
||||
if len(index) != len(expected) {
|
||||
t.Errorf("Expected %d indexed items, got %d", len(expected), len(index))
|
||||
}
|
||||
|
||||
for id, expectedPath := range expected {
|
||||
if actualPath, exists := index[id]; !exists || actualPath != expectedPath {
|
||||
t.Errorf("ID %s: expected path '%s', got '%s'", id, expectedPath, actualPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndexConfigObjects_InvalidID(t *testing.T) {
|
||||
config := map[string]any{
|
||||
"app": map[string]any{
|
||||
"@id": map[string]any{"invalid": "id"}, // Invalid ID type
|
||||
},
|
||||
}
|
||||
|
||||
index := make(map[string]string)
|
||||
err := indexConfigObjects(config, "/config", index)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid ID type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_AppStartFailure(t *testing.T) {
|
||||
// Register a mock app that fails to start
|
||||
RegisterModule(&failingApp{})
|
||||
defer func() {
|
||||
// Clean up module registry
|
||||
delete(modules, "failing-app")
|
||||
}()
|
||||
|
||||
cfg := &Config{
|
||||
Admin: &AdminConfig{Disabled: true},
|
||||
AppsRaw: ModuleMap{
|
||||
"failing-app": json.RawMessage(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := run(cfg, true)
|
||||
if err == nil {
|
||||
t.Error("Expected error when app fails to start")
|
||||
}
|
||||
|
||||
// Should contain the app name in the error
|
||||
if err.Error() == "" {
|
||||
t.Error("Expected descriptive error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_AppStopFailure_During_Cleanup(t *testing.T) {
|
||||
// Register apps where one fails to start and another fails to stop
|
||||
RegisterModule(&workingApp{})
|
||||
RegisterModule(&failingStopApp{})
|
||||
defer func() {
|
||||
delete(modules, "working-app")
|
||||
delete(modules, "failing-stop-app")
|
||||
}()
|
||||
|
||||
cfg := &Config{
|
||||
Admin: &AdminConfig{Disabled: true},
|
||||
AppsRaw: ModuleMap{
|
||||
"working-app": json.RawMessage(`{}`),
|
||||
"failing-stop-app": json.RawMessage(`{}`),
|
||||
},
|
||||
}
|
||||
|
||||
// Start both apps
|
||||
ctx, err := run(cfg, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error starting apps: %v", err)
|
||||
}
|
||||
|
||||
// Stop context - this should handle stop failures gracefully
|
||||
unsyncedStop(ctx)
|
||||
|
||||
// Test passed if we reach here without panic
|
||||
}
|
||||
|
||||
func TestProvisionContext_NilConfig(t *testing.T) {
|
||||
ctx, err := provisionContext(nil, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if ctx.cfg == nil {
|
||||
t.Error("Expected non-nil config even when input is nil")
|
||||
}
|
||||
|
||||
// Clean up
|
||||
ctx.cfg.cancelFunc()
|
||||
}
|
||||
|
||||
func TestDuration_UnmarshalJSON_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "integer nanoseconds",
|
||||
input: "1000000000",
|
||||
expected: time.Second,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "string duration",
|
||||
input: `"5m30s"`,
|
||||
expected: 5*time.Minute + 30*time.Second,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "days conversion",
|
||||
input: `"2d"`,
|
||||
expected: 48 * time.Hour,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "mixed days and hours",
|
||||
input: `"1d12h"`,
|
||||
expected: 36 * time.Hour,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid duration",
|
||||
input: `"invalid"`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var d Duration
|
||||
err := d.UnmarshalJSON([]byte(test.input))
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !test.expectErr && time.Duration(d) != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, time.Duration(d))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_LongInput(t *testing.T) {
|
||||
// Test input length limit
|
||||
longInput := string(make([]byte, 1025)) // Exceeds 1024 limit
|
||||
for i := range longInput {
|
||||
longInput = longInput[:i] + "1"
|
||||
}
|
||||
longInput += "d"
|
||||
|
||||
_, err := ParseDuration(longInput)
|
||||
if err == nil {
|
||||
t.Error("Expected error for input longer than 1024 characters")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVersion_Deterministic(t *testing.T) {
|
||||
// Test that Version() returns consistent results
|
||||
simple1, full1 := Version()
|
||||
simple2, full2 := Version()
|
||||
|
||||
if simple1 != simple2 {
|
||||
t.Errorf("Version() simple form not deterministic: '%s' != '%s'", simple1, simple2)
|
||||
}
|
||||
if full1 != full2 {
|
||||
t.Errorf("Version() full form not deterministic: '%s' != '%s'", full1, full2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstanceID_Consistency(t *testing.T) {
|
||||
// Test that InstanceID returns the same ID on subsequent calls
|
||||
id1, err := InstanceID()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get instance ID: %v", err)
|
||||
}
|
||||
|
||||
id2, err := InstanceID()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get instance ID on second call: %v", err)
|
||||
}
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("InstanceID not consistent: %v != %v", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveMetaFields_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no meta fields",
|
||||
input: `{"normal": "field"}`,
|
||||
expected: `{"normal": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "single @id field",
|
||||
input: `{"@id": "test", "other": "field"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "@id at beginning",
|
||||
input: `{"@id": "test", "other": "field"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "@id at end",
|
||||
input: `{"other": "field", "@id": "test"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "@id in middle",
|
||||
input: `{"first": "value", "@id": "test", "last": "value"}`,
|
||||
expected: `{"first": "value", "last": "value"}`,
|
||||
},
|
||||
{
|
||||
name: "multiple @id fields",
|
||||
input: `{"@id": "test1", "other": "field", "@id": "test2"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "numeric @id",
|
||||
input: `{"@id": 123, "other": "field"}`,
|
||||
expected: `{"other": "field"}`,
|
||||
},
|
||||
{
|
||||
name: "nested objects with @id",
|
||||
input: `{"outer": {"@id": "nested", "data": "value"}}`,
|
||||
expected: `{"outer": {"data": "value"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := RemoveMetaFields([]byte(test.input))
|
||||
// resultStr := string(result)
|
||||
|
||||
// Parse both to ensure valid JSON and compare structures
|
||||
var expectedObj, resultObj any
|
||||
if err := json.Unmarshal([]byte(test.expected), &expectedObj); err != nil {
|
||||
t.Fatalf("Expected result is not valid JSON: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(result, &resultObj); err != nil {
|
||||
t.Fatalf("Result is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Note: We can't do exact string comparison due to potential field ordering
|
||||
// Instead, verify the structure matches
|
||||
expectedJSON, _ := json.Marshal(expectedObj)
|
||||
resultJSON, _ := json.Marshal(resultObj)
|
||||
|
||||
if string(expectedJSON) != string(resultJSON) {
|
||||
t.Errorf("Expected %s, got %s", string(expectedJSON), string(resultJSON))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsyncedConfigAccess_ArrayOperations_EdgeCases(t *testing.T) {
|
||||
// Test array boundary conditions and edge cases
|
||||
tests := []struct {
|
||||
name string
|
||||
initialState map[string]any
|
||||
method string
|
||||
path string
|
||||
payload string
|
||||
expectErr bool
|
||||
expectState map[string]any
|
||||
}{
|
||||
{
|
||||
name: "delete from empty array",
|
||||
initialState: map[string]any{"arr": []any{}},
|
||||
method: http.MethodDelete,
|
||||
path: "/config/arr/0",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "access negative index",
|
||||
initialState: map[string]any{"arr": []any{"a", "b"}},
|
||||
method: http.MethodGet,
|
||||
path: "/config/arr/-1",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "put at index beyond end",
|
||||
initialState: map[string]any{"arr": []any{"a"}},
|
||||
method: http.MethodPut,
|
||||
path: "/config/arr/5",
|
||||
payload: `"new"`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "patch non-existent index",
|
||||
initialState: map[string]any{"arr": []any{"a"}},
|
||||
method: http.MethodPatch,
|
||||
path: "/config/arr/5",
|
||||
payload: `"new"`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "put at exact end of array",
|
||||
initialState: map[string]any{"arr": []any{"a", "b"}},
|
||||
method: http.MethodPut,
|
||||
path: "/config/arr/2",
|
||||
payload: `"c"`,
|
||||
expectState: map[string]any{"arr": []any{"a", "b", "c"}},
|
||||
},
|
||||
{
|
||||
name: "ellipses with non-array payload",
|
||||
initialState: map[string]any{"arr": []any{"a"}},
|
||||
method: http.MethodPost,
|
||||
path: "/config/arr/...",
|
||||
payload: `"not-array"`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Set up initial state
|
||||
rawCfg[rawConfigKey] = test.initialState
|
||||
|
||||
err := unsyncedConfigAccess(test.method, test.path, []byte(test.payload), nil)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if test.expectState != nil {
|
||||
// Compare resulting state
|
||||
expectedJSON, _ := json.Marshal(test.expectState)
|
||||
actualJSON, _ := json.Marshal(rawCfg[rawConfigKey])
|
||||
|
||||
if string(expectedJSON) != string(actualJSON) {
|
||||
t.Errorf("Expected state %s, got %s", string(expectedJSON), string(actualJSON))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitProcess_ConcurrentCalls(t *testing.T) {
|
||||
// Test that multiple concurrent calls to exitProcess are safe
|
||||
// We can't test the actual exit, but we can test the atomic flag
|
||||
|
||||
// Reset the exiting flag
|
||||
oldExiting := exiting
|
||||
exiting = new(int32)
|
||||
defer func() { exiting = oldExiting }()
|
||||
|
||||
const numGoroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
results := make([]bool, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
// Check the Exiting() function which reads the atomic flag
|
||||
wasExitingBefore := Exiting()
|
||||
|
||||
// This would call exitProcess, but we don't want to actually exit
|
||||
// So we just test the atomic operation directly
|
||||
results[index] = atomic.CompareAndSwapInt32(exiting, 0, 1)
|
||||
|
||||
wasExitingAfter := Exiting()
|
||||
|
||||
// At least one should succeed in setting the flag
|
||||
if !wasExitingBefore && wasExitingAfter && !results[index] {
|
||||
t.Errorf("Goroutine %d: Flag was set but CAS failed", index)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Exactly one goroutine should have successfully set the flag
|
||||
successCount := 0
|
||||
for _, success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
if successCount != 1 {
|
||||
t.Errorf("Expected exactly 1 successful flag set, got %d", successCount)
|
||||
}
|
||||
|
||||
// Flag should be set
|
||||
if !Exiting() {
|
||||
t.Error("Exiting flag should be set")
|
||||
}
|
||||
}
|
||||
|
||||
// Mock apps for testing
|
||||
type failingApp struct{}
|
||||
|
||||
func (fa *failingApp) CaddyModule() ModuleInfo {
|
||||
return ModuleInfo{
|
||||
ID: "failing-app",
|
||||
New: func() Module { return new(failingApp) },
|
||||
}
|
||||
}
|
||||
|
||||
func (fa *failingApp) Start() error {
|
||||
return fmt.Errorf("simulated start failure")
|
||||
}
|
||||
|
||||
func (fa *failingApp) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type workingApp struct{}
|
||||
|
||||
func (wa *workingApp) CaddyModule() ModuleInfo {
|
||||
return ModuleInfo{
|
||||
ID: "working-app",
|
||||
New: func() Module { return new(workingApp) },
|
||||
}
|
||||
}
|
||||
|
||||
func (wa *workingApp) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wa *workingApp) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type failingStopApp struct{}
|
||||
|
||||
func (fsa *failingStopApp) CaddyModule() ModuleInfo {
|
||||
return ModuleInfo{
|
||||
ID: "failing-stop-app",
|
||||
New: func() Module { return new(failingStopApp) },
|
||||
}
|
||||
}
|
||||
|
||||
func (fsa *failingStopApp) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fsa *failingStopApp) Stop() error {
|
||||
return fmt.Errorf("simulated stop failure")
|
||||
}
|
||||
407
duration_test.go
Normal file
407
duration_test.go
Normal file
|
|
@ -0,0 +1,407 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseDuration_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "zero duration",
|
||||
input: "0",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "invalid format",
|
||||
input: "abc",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative days",
|
||||
input: "-2d",
|
||||
expected: -48 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "decimal days",
|
||||
input: "0.5d",
|
||||
expected: 12 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "large decimal days",
|
||||
input: "365.25d",
|
||||
expected: time.Duration(365.25*24) * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "multiple days in same string",
|
||||
input: "1d2d3d",
|
||||
expected: (24 * 6) * time.Hour, // 6 days total
|
||||
},
|
||||
{
|
||||
name: "days with other units",
|
||||
input: "1d30m15s",
|
||||
expected: 24*time.Hour + 30*time.Minute + 15*time.Second,
|
||||
},
|
||||
{
|
||||
name: "malformed days",
|
||||
input: "d",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid day value",
|
||||
input: "abcd",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "overflow protection",
|
||||
input: "9999999999999999999999999d",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "zero days",
|
||||
input: "0d",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "input at limit",
|
||||
input: strings.Repeat("1", 1024) + "ns",
|
||||
expectErr: true, // Likely to cause parsing error due to size
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result, err := ParseDuration(test.input)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !test.expectErr && result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_InputLengthLimit(t *testing.T) {
|
||||
// Test the 1024 character limit
|
||||
longInput := strings.Repeat("1", 1025) + "s"
|
||||
|
||||
_, err := ParseDuration(longInput)
|
||||
if err == nil {
|
||||
t.Error("Expected error for input longer than 1024 characters")
|
||||
}
|
||||
|
||||
expectedErrMsg := "parsing duration: input string too long"
|
||||
if err.Error() != expectedErrMsg {
|
||||
t.Errorf("Expected error message '%s', got '%s'", expectedErrMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_ComplexNumberFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
input: "+1d",
|
||||
expected: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
input: "-1.5d",
|
||||
expected: -36 * time.Hour,
|
||||
},
|
||||
{
|
||||
input: "1.0d",
|
||||
expected: 24 * time.Hour,
|
||||
},
|
||||
{
|
||||
input: "0.25d",
|
||||
expected: 6 * time.Hour,
|
||||
},
|
||||
{
|
||||
input: "1.5d30m",
|
||||
expected: 36*time.Hour + 30*time.Minute,
|
||||
},
|
||||
{
|
||||
input: "2.5d1h30m45s",
|
||||
expected: 60*time.Hour + time.Hour + 30*time.Minute + 45*time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
result, err := ParseDuration(test.input)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_UnmarshalJSON_TypeValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "null value",
|
||||
input: "null",
|
||||
expectErr: false,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "boolean value",
|
||||
input: "true",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "array value",
|
||||
input: `[1,2,3]`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "object value",
|
||||
input: `{"duration": "5m"}`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative integer",
|
||||
input: "-1000000000",
|
||||
expected: -time.Second,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "zero integer",
|
||||
input: "0",
|
||||
expected: 0,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "large integer",
|
||||
input: "9223372036854775807", // Max int64
|
||||
expected: time.Duration(math.MaxInt64),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "float as integer (invalid JSON for int)",
|
||||
input: "1.5",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "string with special characters",
|
||||
input: `"5m\"30s"`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "string with unicode",
|
||||
input: `"5m🚀"`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var d Duration
|
||||
err := d.UnmarshalJSON([]byte(test.input))
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if !test.expectErr && time.Duration(d) != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, time.Duration(d))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_JSON_RoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
duration time.Duration
|
||||
asString bool
|
||||
}{
|
||||
{duration: 5 * time.Minute, asString: true},
|
||||
{duration: 24 * time.Hour, asString: false}, // Will be stored as nanoseconds
|
||||
{duration: 0, asString: false},
|
||||
{duration: -time.Hour, asString: true},
|
||||
{duration: time.Nanosecond, asString: false},
|
||||
{duration: time.Second, asString: false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.duration.String(), func(t *testing.T) {
|
||||
d := Duration(test.duration)
|
||||
|
||||
// Marshal to JSON
|
||||
jsonData, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var unmarshaled Duration
|
||||
err = unmarshaled.UnmarshalJSON(jsonData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Should be equal
|
||||
if time.Duration(unmarshaled) != test.duration {
|
||||
t.Errorf("Round trip failed: expected %v, got %v", test.duration, time.Duration(unmarshaled))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_Precision(t *testing.T) {
|
||||
// Test floating point precision with days
|
||||
tests := []struct {
|
||||
input string
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
input: "0.1d",
|
||||
expected: time.Duration(0.1 * 24 * float64(time.Hour)),
|
||||
},
|
||||
{
|
||||
input: "0.01d",
|
||||
expected: time.Duration(0.01 * 24 * float64(time.Hour)),
|
||||
},
|
||||
{
|
||||
input: "0.001d",
|
||||
expected: time.Duration(0.001 * 24 * float64(time.Hour)),
|
||||
},
|
||||
{
|
||||
input: "1.23456789d",
|
||||
expected: time.Duration(1.23456789 * 24 * float64(time.Hour)),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
result, err := ParseDuration(test.input)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Allow for small floating point differences
|
||||
diff := result - test.expected
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
if diff > time.Nanosecond {
|
||||
t.Errorf("Expected %v, got %v (diff: %v)", test.expected, result, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDuration_Boundary_Values(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "minimum day value",
|
||||
input: "0.000000001d", // Very small but valid
|
||||
},
|
||||
{
|
||||
name: "very large day value",
|
||||
input: "999999999999999999999d",
|
||||
expectErr: true, // Should overflow
|
||||
},
|
||||
{
|
||||
name: "negative zero",
|
||||
input: "-0d",
|
||||
},
|
||||
{
|
||||
name: "positive zero",
|
||||
input: "+0d",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
_, err := ParseDuration(test.input)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseDuration_SimpleDay(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ParseDuration("1d")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseDuration_ComplexDay(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ParseDuration("1.5d30m15.5s")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseDuration_MultipleDays(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ParseDuration("1d2d3d4d5d")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDuration_UnmarshalJSON_String(b *testing.B) {
|
||||
input := []byte(`"5m30s"`)
|
||||
var d Duration
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
d.UnmarshalJSON(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDuration_UnmarshalJSON_Integer(b *testing.B) {
|
||||
input := []byte("300000000000") // 5 minutes in nanoseconds
|
||||
var d Duration
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
d.UnmarshalJSON(input)
|
||||
}
|
||||
}
|
||||
642
event_test.go
Normal file
642
event_test.go
Normal file
|
|
@ -0,0 +1,642 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewEvent_Basic(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
eventName := "test.event"
|
||||
eventData := map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
}
|
||||
|
||||
event, err := NewEvent(ctx, eventName, eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
// Verify event properties
|
||||
if event.Name() != eventName {
|
||||
t.Errorf("Expected name '%s', got '%s'", eventName, event.Name())
|
||||
}
|
||||
|
||||
if event.Data == nil {
|
||||
t.Error("Expected non-nil data")
|
||||
}
|
||||
|
||||
if len(event.Data) != len(eventData) {
|
||||
t.Errorf("Expected %d data items, got %d", len(eventData), len(event.Data))
|
||||
}
|
||||
|
||||
for key, expectedValue := range eventData {
|
||||
if actualValue, exists := event.Data[key]; !exists || actualValue != expectedValue {
|
||||
t.Errorf("Data key '%s': expected %v, got %v", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify ID is generated
|
||||
if event.ID().String() == "" {
|
||||
t.Error("Event ID should not be empty")
|
||||
}
|
||||
|
||||
// Verify timestamp is recent
|
||||
if time.Since(event.Timestamp()) > time.Second {
|
||||
t.Error("Event timestamp should be recent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEvent_NameNormalization(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"UPPERCASE", "uppercase"},
|
||||
{"MixedCase", "mixedcase"},
|
||||
{"already.lower", "already.lower"},
|
||||
{"With-Dashes", "with-dashes"},
|
||||
{"With_Underscores", "with_underscores"},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.input, func(t *testing.T) {
|
||||
event, err := NewEvent(ctx, test.input, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
if event.Name() != test.expected {
|
||||
t.Errorf("Expected normalized name '%s', got '%s'", test.expected, event.Name())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_CloudEvent_NilData(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// Should not panic with nil data
|
||||
if cloudEvent.Data == nil {
|
||||
t.Error("CloudEvent data should not be nil even with nil input")
|
||||
}
|
||||
|
||||
// Should be valid JSON
|
||||
var parsed any
|
||||
if err := json.Unmarshal(cloudEvent.Data, &parsed); err != nil {
|
||||
t.Errorf("CloudEvent data should be valid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_CloudEvent_WithModule(t *testing.T) {
|
||||
// Create a context with a mock module
|
||||
mockMod := &mockModule{}
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Simulate module ancestry
|
||||
ctx.ancestry = []Module{mockMod}
|
||||
|
||||
event, err := NewEvent(ctx, "test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// Source should be the module ID
|
||||
expectedSource := string(mockMod.CaddyModule().ID)
|
||||
if cloudEvent.Source != expectedSource {
|
||||
t.Errorf("Expected source '%s', got '%s'", expectedSource, cloudEvent.Source)
|
||||
}
|
||||
|
||||
// Origin should be the module
|
||||
if event.Origin() != mockMod {
|
||||
t.Error("Expected event origin to be the mock module")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_CloudEvent_Fields(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
eventName := "test.event"
|
||||
eventData := map[string]any{"test": "data"}
|
||||
|
||||
event, err := NewEvent(ctx, eventName, eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// Verify CloudEvent fields
|
||||
if cloudEvent.ID == "" {
|
||||
t.Error("CloudEvent ID should not be empty")
|
||||
}
|
||||
|
||||
if cloudEvent.Source != "caddy" {
|
||||
t.Errorf("Expected source 'caddy' for nil module, got '%s'", cloudEvent.Source)
|
||||
}
|
||||
|
||||
if cloudEvent.SpecVersion != "1.0" {
|
||||
t.Errorf("Expected spec version '1.0', got '%s'", cloudEvent.SpecVersion)
|
||||
}
|
||||
|
||||
if cloudEvent.Type != eventName {
|
||||
t.Errorf("Expected type '%s', got '%s'", eventName, cloudEvent.Type)
|
||||
}
|
||||
|
||||
if cloudEvent.DataContentType != "application/json" {
|
||||
t.Errorf("Expected content type 'application/json', got '%s'", cloudEvent.DataContentType)
|
||||
}
|
||||
|
||||
// Verify data is valid JSON
|
||||
var parsedData map[string]any
|
||||
if err := json.Unmarshal(cloudEvent.Data, &parsedData); err != nil {
|
||||
t.Errorf("CloudEvent data is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
if parsedData["test"] != "data" {
|
||||
t.Errorf("Expected data to contain test='data', got %v", parsedData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_ConcurrentAccess(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "concurrent.test", map[string]any{
|
||||
"counter": 0,
|
||||
"data": "shared",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
const numGoroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Test concurrent read access to event properties
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// These should be safe for concurrent access
|
||||
_ = event.ID()
|
||||
_ = event.Name()
|
||||
_ = event.Timestamp()
|
||||
_ = event.Origin()
|
||||
_ = event.CloudEvent()
|
||||
|
||||
// Data map is not synchronized, so read-only access should be safe
|
||||
if data, exists := event.Data["data"]; !exists || data != "shared" {
|
||||
t.Errorf("Goroutine %d: Expected shared data", id)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestEvent_DataModification_Warning(t *testing.T) {
|
||||
// This test documents the non-thread-safe nature of event data
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "data.test", map[string]any{
|
||||
"mutable": "original",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
// Modifying data after creation (this is allowed but not thread-safe)
|
||||
event.Data["mutable"] = "modified"
|
||||
event.Data["new_key"] = "new_value"
|
||||
|
||||
// Verify modifications are visible
|
||||
if event.Data["mutable"] != "modified" {
|
||||
t.Error("Data modification should be visible")
|
||||
}
|
||||
if event.Data["new_key"] != "new_value" {
|
||||
t.Error("New data should be visible")
|
||||
}
|
||||
|
||||
// CloudEvent should reflect the current state
|
||||
cloudEvent := event.CloudEvent()
|
||||
var parsedData map[string]any
|
||||
json.Unmarshal(cloudEvent.Data, &parsedData)
|
||||
|
||||
if parsedData["mutable"] != "modified" {
|
||||
t.Error("CloudEvent should reflect modified data")
|
||||
}
|
||||
if parsedData["new_key"] != "new_value" {
|
||||
t.Error("CloudEvent should reflect new data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_Aborted_State(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "abort.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
// Initially not aborted
|
||||
if event.Aborted != nil {
|
||||
t.Error("Event should not be aborted initially")
|
||||
}
|
||||
|
||||
// Simulate aborting the event
|
||||
event.Aborted = ErrEventAborted
|
||||
|
||||
if event.Aborted != ErrEventAborted {
|
||||
t.Error("Event should be marked as aborted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrEventAborted_Value(t *testing.T) {
|
||||
if ErrEventAborted == nil {
|
||||
t.Error("ErrEventAborted should not be nil")
|
||||
}
|
||||
|
||||
if ErrEventAborted.Error() != "event aborted" {
|
||||
t.Errorf("Expected 'event aborted', got '%s'", ErrEventAborted.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_UniqueIDs(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
const numEvents = 1000
|
||||
ids := make(map[string]bool)
|
||||
|
||||
for i := 0; i < numEvents; i++ {
|
||||
event, err := NewEvent(ctx, "unique.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event %d: %v", i, err)
|
||||
}
|
||||
|
||||
idStr := event.ID().String()
|
||||
if ids[idStr] {
|
||||
t.Errorf("Duplicate event ID: %s", idStr)
|
||||
}
|
||||
ids[idStr] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_TimestampProgression(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Create events with small delays
|
||||
events := make([]Event, 5)
|
||||
for i := range events {
|
||||
var err error
|
||||
events[i], err = NewEvent(ctx, "time.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event %d: %v", i, err)
|
||||
}
|
||||
|
||||
if i < len(events)-1 {
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify timestamps are in ascending order
|
||||
for i := 1; i < len(events); i++ {
|
||||
if !events[i].Timestamp().After(events[i-1].Timestamp()) {
|
||||
t.Errorf("Event %d timestamp (%v) should be after event %d timestamp (%v)",
|
||||
i, events[i].Timestamp(), i-1, events[i-1].Timestamp())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_JSON_Serialization(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
eventData := map[string]any{
|
||||
"string": "value",
|
||||
"number": 42,
|
||||
"boolean": true,
|
||||
"array": []any{1, 2, 3},
|
||||
"object": map[string]any{"nested": "value"},
|
||||
}
|
||||
|
||||
event, err := NewEvent(ctx, "json.test", eventData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// CloudEvent should be JSON serializable
|
||||
cloudEventJSON, err := json.Marshal(cloudEvent)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal CloudEvent: %v", err)
|
||||
}
|
||||
|
||||
// Should be able to unmarshal back
|
||||
var parsed CloudEvent
|
||||
err = json.Unmarshal(cloudEventJSON, &parsed)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal CloudEvent: %v", err)
|
||||
}
|
||||
|
||||
// Verify key fields survived round-trip
|
||||
if parsed.ID != cloudEvent.ID {
|
||||
t.Errorf("ID mismatch after round-trip")
|
||||
}
|
||||
if parsed.Source != cloudEvent.Source {
|
||||
t.Errorf("Source mismatch after round-trip")
|
||||
}
|
||||
if parsed.Type != cloudEvent.Type {
|
||||
t.Errorf("Type mismatch after round-trip")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_EmptyData(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Test with empty map
|
||||
event1, err := NewEvent(ctx, "empty.map", map[string]any{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event with empty map: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent1 := event1.CloudEvent()
|
||||
var parsed1 map[string]any
|
||||
json.Unmarshal(cloudEvent1.Data, &parsed1)
|
||||
if len(parsed1) != 0 {
|
||||
t.Error("Expected empty data map")
|
||||
}
|
||||
|
||||
// Test with nil data
|
||||
event2, err := NewEvent(ctx, "nil.data", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event with nil data: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent2 := event2.CloudEvent()
|
||||
if cloudEvent2.Data == nil {
|
||||
t.Error("CloudEvent data should not be nil even with nil input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_Origin_WithModule(t *testing.T) {
|
||||
mockMod := &mockEventModule{}
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Set module in ancestry
|
||||
ctx.ancestry = []Module{mockMod}
|
||||
|
||||
event, err := NewEvent(ctx, "module.test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
if event.Origin() != mockMod {
|
||||
t.Error("Expected event origin to be the mock module")
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
expectedSource := string(mockMod.CaddyModule().ID)
|
||||
if cloudEvent.Source != expectedSource {
|
||||
t.Errorf("Expected source '%s', got '%s'", expectedSource, cloudEvent.Source)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_LargeData(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Create event with large data
|
||||
largeData := make(map[string]any)
|
||||
for i := 0; i < 1000; i++ {
|
||||
largeData[fmt.Sprintf("key%d", i)] = fmt.Sprintf("value%d", i)
|
||||
}
|
||||
|
||||
event, err := NewEvent(ctx, "large.data", largeData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event with large data: %v", err)
|
||||
}
|
||||
|
||||
// CloudEvent should handle large data
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
var parsedData map[string]any
|
||||
err = json.Unmarshal(cloudEvent.Data, &parsedData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse large data in CloudEvent: %v", err)
|
||||
}
|
||||
|
||||
if len(parsedData) != len(largeData) {
|
||||
t.Errorf("Expected %d data items, got %d", len(largeData), len(parsedData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_SpecialCharacters_InData(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
specialData := map[string]any{
|
||||
"unicode": "🚀✨",
|
||||
"newlines": "line1\nline2\r\nline3",
|
||||
"quotes": `"double" and 'single' quotes`,
|
||||
"backslashes": "\\path\\to\\file",
|
||||
"json_chars": `{"key": "value"}`,
|
||||
"empty": "",
|
||||
"null_value": nil,
|
||||
}
|
||||
|
||||
event, err := NewEvent(ctx, "special.chars", specialData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
cloudEvent := event.CloudEvent()
|
||||
|
||||
// Should produce valid JSON
|
||||
var parsedData map[string]any
|
||||
err = json.Unmarshal(cloudEvent.Data, &parsedData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse data with special characters: %v", err)
|
||||
}
|
||||
|
||||
// Verify some special cases survived JSON round-trip
|
||||
if parsedData["unicode"] != "🚀✨" {
|
||||
t.Error("Unicode characters should survive JSON encoding")
|
||||
}
|
||||
|
||||
if parsedData["quotes"] != `"double" and 'single' quotes` {
|
||||
t.Error("Quotes should be properly escaped in JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_ConcurrentCreation(t *testing.T) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
const numGoroutines = 100
|
||||
var wg sync.WaitGroup
|
||||
events := make([]Event, numGoroutines)
|
||||
errors := make([]error, numGoroutines)
|
||||
|
||||
// Create events concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
|
||||
eventData := map[string]any{
|
||||
"goroutine": index,
|
||||
"timestamp": time.Now().UnixNano(),
|
||||
}
|
||||
|
||||
events[index], errors[index] = NewEvent(ctx, "concurrent.test", eventData)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all events were created successfully
|
||||
ids := make(map[string]bool)
|
||||
for i, event := range events {
|
||||
if errors[i] != nil {
|
||||
t.Errorf("Goroutine %d: Failed to create event: %v", i, errors[i])
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify unique IDs
|
||||
idStr := event.ID().String()
|
||||
if ids[idStr] {
|
||||
t.Errorf("Duplicate event ID: %s", idStr)
|
||||
}
|
||||
ids[idStr] = true
|
||||
|
||||
// Verify data integrity
|
||||
if goroutineID, exists := event.Data["goroutine"]; !exists || goroutineID != i {
|
||||
t.Errorf("Event %d: Data corruption detected", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mock module for event testing
|
||||
type mockEventModule struct{}
|
||||
|
||||
func (m *mockEventModule) CaddyModule() ModuleInfo {
|
||||
return ModuleInfo{
|
||||
ID: "test.event.module",
|
||||
New: func() Module { return new(mockEventModule) },
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvent_TimeAccuracy(t *testing.T) {
|
||||
before := time.Now()
|
||||
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, err := NewEvent(ctx, "time.accuracy", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create event: %v", err)
|
||||
}
|
||||
|
||||
after := time.Now()
|
||||
eventTime := event.Timestamp()
|
||||
|
||||
// Event timestamp should be between before and after
|
||||
if eventTime.Before(before) || eventTime.After(after) {
|
||||
t.Errorf("Event timestamp %v should be between %v and %v", eventTime, before, after)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewEvent(b *testing.B) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
eventData := map[string]any{
|
||||
"key1": "value1",
|
||||
"key2": 42,
|
||||
"key3": true,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewEvent(ctx, "benchmark.test", eventData)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEvent_CloudEvent(b *testing.B) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
event, _ := NewEvent(ctx, "benchmark.cloud", map[string]any{
|
||||
"data": "test",
|
||||
"num": 123,
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
event.CloudEvent()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEvent_CloudEvent_LargeData(b *testing.B) {
|
||||
ctx, cancel := NewContext(Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
// Create event with substantial data
|
||||
largeData := make(map[string]any)
|
||||
for i := 0; i < 100; i++ {
|
||||
largeData[fmt.Sprintf("key%d", i)] = fmt.Sprintf("value%d", i)
|
||||
}
|
||||
|
||||
event, _ := NewEvent(ctx, "benchmark.large", largeData)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
event.CloudEvent()
|
||||
}
|
||||
}
|
||||
221
filepath_test.go
Normal file
221
filepath_test.go
Normal file
|
|
@ -0,0 +1,221 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !windows
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFastAbs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
checkFunc func(result string, err error) error
|
||||
}{
|
||||
{
|
||||
name: "absolute path",
|
||||
input: "/usr/local/bin",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if result != "/usr/local/bin" {
|
||||
t.Errorf("expected /usr/local/bin, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "absolute path with dots",
|
||||
input: "/usr/local/../bin",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if result != "/usr/bin" {
|
||||
t.Errorf("expected /usr/bin, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "relative path",
|
||||
input: "relative/path",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
if !strings.HasSuffix(result, "relative/path") {
|
||||
t.Errorf("expected path to end with 'relative/path', got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dot",
|
||||
input: ".",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "dot dot",
|
||||
input: "..",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
// Empty string should resolve to current directory
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex relative path",
|
||||
input: "./foo/../bar/./baz",
|
||||
checkFunc: func(result string, err error) error {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if !filepath.IsAbs(result) {
|
||||
t.Errorf("expected absolute path, got %s", result)
|
||||
}
|
||||
if !strings.HasSuffix(result, "bar/baz") {
|
||||
t.Errorf("expected path to end with 'bar/baz', got %s", result)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := FastAbs(tt.input)
|
||||
tt.checkFunc(result, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastAbsVsFilepathAbs compares FastAbs with filepath.Abs to ensure consistent behavior
|
||||
func TestFastAbsVsFilepathAbs(t *testing.T) {
|
||||
// Skip if working directory cannot be determined
|
||||
if wderr != nil {
|
||||
t.Skip("working directory error, skipping comparison test")
|
||||
}
|
||||
|
||||
testPaths := []string{
|
||||
".",
|
||||
"..",
|
||||
"foo",
|
||||
"foo/bar",
|
||||
"./foo",
|
||||
"../foo",
|
||||
"/absolute/path",
|
||||
"/usr/local/bin",
|
||||
}
|
||||
|
||||
for _, path := range testPaths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
fast, fastErr := FastAbs(path)
|
||||
std, stdErr := filepath.Abs(path)
|
||||
|
||||
// Both should succeed or fail together
|
||||
if (fastErr != nil) != (stdErr != nil) {
|
||||
t.Errorf("error mismatch: FastAbs=%v, filepath.Abs=%v", fastErr, stdErr)
|
||||
}
|
||||
|
||||
// If both succeed, results should be the same
|
||||
if fastErr == nil && stdErr == nil && fast != std {
|
||||
t.Errorf("result mismatch for %q: FastAbs=%s, filepath.Abs=%s", path, fast, std)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFastAbsErrorHandling tests error handling when working directory is unavailable
|
||||
func TestFastAbsErrorHandling(t *testing.T) {
|
||||
// This tests the cached wderr behavior
|
||||
if wderr != nil {
|
||||
// Test that FastAbs properly returns the cached error for relative paths
|
||||
_, err := FastAbs("relative/path")
|
||||
if err == nil {
|
||||
t.Error("expected error for relative path when working directory is unavailable")
|
||||
}
|
||||
if err != wderr {
|
||||
t.Errorf("expected cached wderr, got different error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFastAbs benchmarks FastAbs
|
||||
func BenchmarkFastAbs(b *testing.B) {
|
||||
paths := []string{
|
||||
"relative/path",
|
||||
"/absolute/path",
|
||||
".",
|
||||
"..",
|
||||
"./foo/bar",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
FastAbs(paths[i%len(paths)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkFastAbsVsStdLib compares performance of FastAbs vs filepath.Abs
|
||||
func BenchmarkFastAbsVsStdLib(b *testing.B) {
|
||||
path := "relative/path/to/file"
|
||||
|
||||
b.Run("FastAbs", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
FastAbs(path)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("filepath.Abs", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
filepath.Abs(path)
|
||||
}
|
||||
})
|
||||
}
|
||||
351
filesystem_test.go
Normal file
351
filesystem_test.go
Normal file
|
|
@ -0,0 +1,351 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mock filesystem implementation for testing
|
||||
type mockFileSystem struct {
|
||||
name string
|
||||
files map[string]string
|
||||
}
|
||||
|
||||
func (m *mockFileSystem) Open(name string) (fs.File, error) {
|
||||
if content, exists := m.files[name]; exists {
|
||||
return &mockFile{name: name, content: content}, nil
|
||||
}
|
||||
return nil, fs.ErrNotExist
|
||||
}
|
||||
|
||||
type mockFile struct {
|
||||
name string
|
||||
content string
|
||||
pos int
|
||||
}
|
||||
|
||||
func (m *mockFile) Stat() (fs.FileInfo, error) {
|
||||
return &mockFileInfo{name: m.name, size: int64(len(m.content))}, nil
|
||||
}
|
||||
|
||||
func (m *mockFile) Read(b []byte) (int, error) {
|
||||
if m.pos >= len(m.content) {
|
||||
return 0, fs.ErrClosed
|
||||
}
|
||||
n := copy(b, m.content[m.pos:])
|
||||
m.pos += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (m *mockFile) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockFileInfo struct {
|
||||
name string
|
||||
size int64
|
||||
}
|
||||
|
||||
func (m *mockFileInfo) Name() string { return m.name }
|
||||
func (m *mockFileInfo) Size() int64 { return m.size }
|
||||
func (m *mockFileInfo) Mode() fs.FileMode { return 0o644 }
|
||||
func (m *mockFileInfo) ModTime() time.Time {
|
||||
return time.Time{}
|
||||
}
|
||||
func (m *mockFileInfo) IsDir() bool { return false }
|
||||
func (m *mockFileInfo) Sys() any { return nil }
|
||||
|
||||
// Mock FileSystems implementation for testing
|
||||
type mockFileSystems struct {
|
||||
mu sync.RWMutex
|
||||
filesystems map[string]fs.FS
|
||||
defaultFS fs.FS
|
||||
}
|
||||
|
||||
func newMockFileSystems() *mockFileSystems {
|
||||
return &mockFileSystems{
|
||||
filesystems: make(map[string]fs.FS),
|
||||
defaultFS: &mockFileSystem{name: "default", files: map[string]string{"default.txt": "default content"}},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockFileSystems) Register(k string, v fs.FS) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.filesystems[k] = v
|
||||
}
|
||||
|
||||
func (m *mockFileSystems) Unregister(k string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.filesystems, k)
|
||||
}
|
||||
|
||||
func (m *mockFileSystems) Get(k string) (fs.FS, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
v, ok := m.filesystems[k]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
func (m *mockFileSystems) Default() fs.FS {
|
||||
return m.defaultFS
|
||||
}
|
||||
|
||||
func TestFileSystems_Register_Get(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
mockFS := &mockFileSystem{
|
||||
name: "test",
|
||||
files: map[string]string{"test.txt": "test content"},
|
||||
}
|
||||
|
||||
// Register filesystem
|
||||
fsys.Register("test", mockFS)
|
||||
|
||||
// Retrieve filesystem
|
||||
retrieved, exists := fsys.Get("test")
|
||||
if !exists {
|
||||
t.Error("Expected filesystem to exist after registration")
|
||||
}
|
||||
if retrieved != mockFS {
|
||||
t.Error("Retrieved filesystem is not the same as registered")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Unregister(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
mockFS := &mockFileSystem{name: "test"}
|
||||
|
||||
// Register then unregister
|
||||
fsys.Register("test", mockFS)
|
||||
fsys.Unregister("test")
|
||||
|
||||
// Should not exist after unregistration
|
||||
_, exists := fsys.Get("test")
|
||||
if exists {
|
||||
t.Error("Filesystem should not exist after unregistration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Default(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
|
||||
defaultFS := fsys.Default()
|
||||
if defaultFS == nil {
|
||||
t.Error("Default filesystem should not be nil")
|
||||
}
|
||||
|
||||
// Test that default filesystem works
|
||||
file, err := defaultFS.Open("default.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open default file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data := make([]byte, 100)
|
||||
n, err := file.Read(data)
|
||||
if err != nil && err != fs.ErrClosed {
|
||||
t.Fatalf("Failed to read default file: %v", err)
|
||||
}
|
||||
|
||||
content := string(data[:n])
|
||||
if content != "default content" {
|
||||
t.Errorf("Expected 'default content', got '%s'", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Concurrent_Access(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
|
||||
const numGoroutines = 50
|
||||
const numOperations = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent register/unregister/get operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
key := fmt.Sprintf("fs-%d", id)
|
||||
mockFS := &mockFileSystem{
|
||||
name: key,
|
||||
files: map[string]string{key + ".txt": "content"},
|
||||
}
|
||||
|
||||
for j := 0; j < numOperations; j++ {
|
||||
// Register
|
||||
fsys.Register(key, mockFS)
|
||||
|
||||
// Get
|
||||
retrieved, exists := fsys.Get(key)
|
||||
if !exists {
|
||||
t.Errorf("Filesystem %s should exist", key)
|
||||
continue
|
||||
}
|
||||
if retrieved != mockFS {
|
||||
t.Errorf("Retrieved filesystem for %s is not correct", key)
|
||||
}
|
||||
|
||||
// Test file access
|
||||
file, err := retrieved.Open(key + ".txt")
|
||||
if err != nil {
|
||||
t.Errorf("Failed to open file in %s: %v", key, err)
|
||||
continue
|
||||
}
|
||||
file.Close()
|
||||
|
||||
// Unregister
|
||||
fsys.Unregister(key)
|
||||
|
||||
// Should not exist after unregister
|
||||
_, stillExists := fsys.Get(key)
|
||||
if stillExists {
|
||||
t.Errorf("Filesystem %s should not exist after unregister", key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestFileSystems_Get_NonExistent(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
|
||||
_, exists := fsys.Get("non-existent")
|
||||
if exists {
|
||||
t.Error("Non-existent filesystem should not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Register_Overwrite(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
key := "overwrite-test"
|
||||
|
||||
// Register first filesystem
|
||||
fs1 := &mockFileSystem{name: "fs1"}
|
||||
fsys.Register(key, fs1)
|
||||
|
||||
// Register second filesystem with same key (should overwrite)
|
||||
fs2 := &mockFileSystem{name: "fs2"}
|
||||
fsys.Register(key, fs2)
|
||||
|
||||
// Should get the second filesystem
|
||||
retrieved, exists := fsys.Get(key)
|
||||
if !exists {
|
||||
t.Error("Filesystem should exist")
|
||||
}
|
||||
if retrieved != fs2 {
|
||||
t.Error("Should get the overwritten filesystem")
|
||||
}
|
||||
if retrieved == fs1 {
|
||||
t.Error("Should not get the original filesystem")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSystems_Concurrent_RegisterUnregister_SameKey(t *testing.T) {
|
||||
fsys := newMockFileSystems()
|
||||
key := "concurrent-key"
|
||||
|
||||
const numGoroutines = 20
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Half the goroutines register, half unregister
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
if i%2 == 0 {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
mockFS := &mockFileSystem{name: fmt.Sprintf("fs-%d", id)}
|
||||
fsys.Register(key, mockFS)
|
||||
}(i)
|
||||
} else {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
fsys.Unregister(key)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// The final state is unpredictable due to race conditions,
|
||||
// but the operations should not panic or cause corruption
|
||||
// Test passes if we reach here without issues
|
||||
}
|
||||
|
||||
func TestFileSystems_StressTest(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping stress test in short mode")
|
||||
}
|
||||
|
||||
fsys := newMockFileSystems()
|
||||
|
||||
const numGoroutines = 100
|
||||
const duration = 100 * time.Millisecond
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Start timer
|
||||
go func() {
|
||||
time.Sleep(duration)
|
||||
close(stopChan)
|
||||
}()
|
||||
|
||||
// Stress test with continuous operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
key := fmt.Sprintf("stress-fs-%d", id%10) // Use limited set of keys
|
||||
mockFS := &mockFileSystem{
|
||||
name: key,
|
||||
files: map[string]string{key + ".txt": "stress content"},
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
default:
|
||||
// Rapid register/get/unregister cycles
|
||||
fsys.Register(key, mockFS)
|
||||
|
||||
if retrieved, exists := fsys.Get(key); exists {
|
||||
// Try to use the filesystem
|
||||
if file, err := retrieved.Open(key + ".txt"); err == nil {
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
fsys.Unregister(key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Test passes if we reach here without panics or deadlocks
|
||||
}
|
||||
|
|
@ -17,6 +17,37 @@ func TestSanitizeMethod(t *testing.T) {
|
|||
{method: "trace", expected: "TRACE"},
|
||||
{method: "UNKNOWN", expected: "OTHER"},
|
||||
{method: strings.Repeat("ohno", 9999), expected: "OTHER"},
|
||||
|
||||
// Test all standard HTTP methods in uppercase
|
||||
{method: "GET", expected: "GET"},
|
||||
{method: "HEAD", expected: "HEAD"},
|
||||
{method: "POST", expected: "POST"},
|
||||
{method: "PUT", expected: "PUT"},
|
||||
{method: "DELETE", expected: "DELETE"},
|
||||
{method: "CONNECT", expected: "CONNECT"},
|
||||
{method: "OPTIONS", expected: "OPTIONS"},
|
||||
{method: "TRACE", expected: "TRACE"},
|
||||
{method: "PATCH", expected: "PATCH"},
|
||||
|
||||
// Test all standard HTTP methods in lowercase
|
||||
{method: "get", expected: "GET"},
|
||||
{method: "head", expected: "HEAD"},
|
||||
{method: "post", expected: "POST"},
|
||||
{method: "put", expected: "PUT"},
|
||||
{method: "delete", expected: "DELETE"},
|
||||
{method: "connect", expected: "CONNECT"},
|
||||
{method: "options", expected: "OPTIONS"},
|
||||
{method: "trace", expected: "TRACE"},
|
||||
{method: "patch", expected: "PATCH"},
|
||||
|
||||
// Test mixed case and non-standard methods
|
||||
{method: "Get", expected: "OTHER"},
|
||||
{method: "gEt", expected: "OTHER"},
|
||||
{method: "UNKNOWN", expected: "OTHER"},
|
||||
{method: "PROPFIND", expected: "OTHER"},
|
||||
{method: "MKCOL", expected: "OTHER"},
|
||||
{method: "", expected: "OTHER"},
|
||||
{method: " ", expected: "OTHER"},
|
||||
}
|
||||
|
||||
for _, d := range tests {
|
||||
|
|
@ -26,3 +57,79 @@ func TestSanitizeMethod(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "zero returns 200",
|
||||
code: 0,
|
||||
expected: "200",
|
||||
},
|
||||
{
|
||||
name: "200 returns 200",
|
||||
code: 200,
|
||||
expected: "200",
|
||||
},
|
||||
{
|
||||
name: "404 returns 404",
|
||||
code: 404,
|
||||
expected: "404",
|
||||
},
|
||||
{
|
||||
name: "500 returns 500",
|
||||
code: 500,
|
||||
expected: "500",
|
||||
},
|
||||
{
|
||||
name: "301 returns 301",
|
||||
code: 301,
|
||||
expected: "301",
|
||||
},
|
||||
{
|
||||
name: "418 teapot returns 418",
|
||||
code: 418,
|
||||
expected: "418",
|
||||
},
|
||||
{
|
||||
name: "999 custom code",
|
||||
code: 999,
|
||||
expected: "999",
|
||||
},
|
||||
{
|
||||
name: "negative code",
|
||||
code: -1,
|
||||
expected: "-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeCode(tt.code)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeCode(%d) = %s; want %s", tt.code, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSanitizeCode benchmarks the SanitizeCode function
|
||||
func BenchmarkSanitizeCode(b *testing.B) {
|
||||
codes := []int{0, 200, 404, 500, 301, 418}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
SanitizeCode(codes[i%len(codes)])
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSanitizeMethod benchmarks the SanitizeMethod function
|
||||
func BenchmarkSanitizeMethod(b *testing.B) {
|
||||
methods := []string{"GET", "POST", "PUT", "DELETE", "UNKNOWN"}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
SanitizeMethod(methods[i%len(methods)])
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -441,7 +441,7 @@ func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort ui
|
|||
if end < start {
|
||||
return NetworkAddress{}, fmt.Errorf("end port must not be less than start port")
|
||||
}
|
||||
if (end - start) > maxPortSpan {
|
||||
if (end-start)+1 > maxPortSpan {
|
||||
return NetworkAddress{}, fmt.Errorf("port range exceeds %d ports", maxPortSpan)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
394
metrics_test.go
Normal file
394
metrics_test.go
Normal file
|
|
@ -0,0 +1,394 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
)
|
||||
|
||||
func TestGlobalMetrics_ConfigSuccess(t *testing.T) {
|
||||
// Test setting config success metric
|
||||
originalValue := getMetricValue(globalMetrics.configSuccess)
|
||||
|
||||
// Set to success
|
||||
globalMetrics.configSuccess.Set(1)
|
||||
newValue := getMetricValue(globalMetrics.configSuccess)
|
||||
|
||||
if newValue != 1 {
|
||||
t.Errorf("Expected config success metric to be 1, got %f", newValue)
|
||||
}
|
||||
|
||||
// Set to failure
|
||||
globalMetrics.configSuccess.Set(0)
|
||||
failureValue := getMetricValue(globalMetrics.configSuccess)
|
||||
|
||||
if failureValue != 0 {
|
||||
t.Errorf("Expected config success metric to be 0, got %f", failureValue)
|
||||
}
|
||||
|
||||
// Restore original value if it existed
|
||||
if originalValue != 0 {
|
||||
globalMetrics.configSuccess.Set(originalValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalMetrics_ConfigSuccessTime(t *testing.T) {
|
||||
// Set success time
|
||||
globalMetrics.configSuccessTime.SetToCurrentTime()
|
||||
|
||||
// Get the metric value
|
||||
metricValue := getMetricValue(globalMetrics.configSuccessTime)
|
||||
|
||||
// Should be a reasonable Unix timestamp (not zero)
|
||||
if metricValue == 0 {
|
||||
t.Error("Config success time should not be zero")
|
||||
}
|
||||
|
||||
// Should be recent (within last minute)
|
||||
now := time.Now().Unix()
|
||||
if int64(metricValue) < now-60 || int64(metricValue) > now {
|
||||
t.Errorf("Config success time %f should be recent (now: %d)", metricValue, now)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminMetrics_RequestCount(t *testing.T) {
|
||||
// Initialize admin metrics for testing
|
||||
initAdminMetrics()
|
||||
|
||||
labels := prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/config",
|
||||
"method": "GET",
|
||||
"code": "200",
|
||||
}
|
||||
|
||||
// Get initial value
|
||||
initialValue := getCounterValue(adminMetrics.requestCount, labels)
|
||||
|
||||
// Increment counter
|
||||
adminMetrics.requestCount.With(labels).Inc()
|
||||
|
||||
// Verify increment
|
||||
newValue := getCounterValue(adminMetrics.requestCount, labels)
|
||||
if newValue != initialValue+1 {
|
||||
t.Errorf("Expected counter to increment by 1, got %f -> %f", initialValue, newValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminMetrics_RequestErrors(t *testing.T) {
|
||||
// Initialize admin metrics for testing
|
||||
initAdminMetrics()
|
||||
|
||||
labels := prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/test",
|
||||
"method": "POST",
|
||||
}
|
||||
|
||||
// Get initial value
|
||||
initialValue := getCounterValue(adminMetrics.requestErrors, labels)
|
||||
|
||||
// Increment error counter
|
||||
adminMetrics.requestErrors.With(labels).Inc()
|
||||
|
||||
// Verify increment
|
||||
newValue := getCounterValue(adminMetrics.requestErrors, labels)
|
||||
if newValue != initialValue+1 {
|
||||
t.Errorf("Expected error counter to increment by 1, got %f -> %f", initialValue, newValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetrics_ConcurrentAccess(t *testing.T) {
|
||||
// Initialize admin metrics
|
||||
initAdminMetrics()
|
||||
|
||||
const numGoroutines = 100
|
||||
const incrementsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
labels := prometheus.Labels{
|
||||
"handler": "concurrent",
|
||||
"path": "/concurrent",
|
||||
"method": "GET",
|
||||
"code": "200",
|
||||
}
|
||||
|
||||
initialCount := getCounterValue(adminMetrics.requestCount, labels)
|
||||
|
||||
// Concurrent increments
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < incrementsPerGoroutine; j++ {
|
||||
adminMetrics.requestCount.With(labels).Inc()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify final count
|
||||
finalCount := getCounterValue(adminMetrics.requestCount, labels)
|
||||
expectedIncrement := float64(numGoroutines * incrementsPerGoroutine)
|
||||
|
||||
if finalCount-initialCount != expectedIncrement {
|
||||
t.Errorf("Expected counter to increase by %f, got %f",
|
||||
expectedIncrement, finalCount-initialCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetrics_LabelValidation(t *testing.T) {
|
||||
// Test various label combinations
|
||||
tests := []struct {
|
||||
name string
|
||||
labels prometheus.Labels
|
||||
metric string
|
||||
}{
|
||||
{
|
||||
name: "valid request count labels",
|
||||
labels: prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/api/test",
|
||||
"method": "GET",
|
||||
"code": "200",
|
||||
},
|
||||
metric: "requestCount",
|
||||
},
|
||||
{
|
||||
name: "valid error labels",
|
||||
labels: prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/api/error",
|
||||
"method": "POST",
|
||||
},
|
||||
metric: "requestErrors",
|
||||
},
|
||||
{
|
||||
name: "empty path",
|
||||
labels: prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "",
|
||||
"method": "GET",
|
||||
"code": "404",
|
||||
},
|
||||
metric: "requestCount",
|
||||
},
|
||||
{
|
||||
name: "special characters in path",
|
||||
labels: prometheus.Labels{
|
||||
"handler": "test",
|
||||
"path": "/api/test%20with%20spaces",
|
||||
"method": "PUT",
|
||||
"code": "201",
|
||||
},
|
||||
metric: "requestCount",
|
||||
},
|
||||
}
|
||||
|
||||
initAdminMetrics()
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// This should not panic or error
|
||||
switch test.metric {
|
||||
case "requestCount":
|
||||
adminMetrics.requestCount.With(test.labels).Inc()
|
||||
case "requestErrors":
|
||||
adminMetrics.requestErrors.With(test.labels).Inc()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetrics_Initialization_Idempotent(t *testing.T) {
|
||||
// Test that initializing admin metrics multiple times is safe
|
||||
for i := 0; i < 5; i++ {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Iteration %d: initAdminMetrics panicked: %v", i, r)
|
||||
}
|
||||
}()
|
||||
initAdminMetrics()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentHandlerCounter(t *testing.T) {
|
||||
// Create a test counter with the expected labels
|
||||
counter := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "test_counter",
|
||||
Help: "Test counter for instrumentation",
|
||||
},
|
||||
[]string{"code", "method"},
|
||||
)
|
||||
|
||||
// Create instrumented handler
|
||||
testHandler := instrumentHandlerCounter(
|
||||
counter,
|
||||
&mockHTTPHandler{statusCode: 200},
|
||||
)
|
||||
|
||||
// Create test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get initial counter value
|
||||
initialValue := getCounterValue(counter, prometheus.Labels{"code": "200", "method": "GET"})
|
||||
|
||||
// Serve request
|
||||
testHandler.ServeHTTP(rr, req)
|
||||
|
||||
// Verify counter was incremented
|
||||
finalValue := getCounterValue(counter, prometheus.Labels{"code": "200", "method": "GET"})
|
||||
if finalValue != initialValue+1 {
|
||||
t.Errorf("Expected counter to increment by 1, got %f -> %f", initialValue, finalValue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentHandlerCounter_ErrorStatus(t *testing.T) {
|
||||
counter := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "test_error_counter",
|
||||
Help: "Test counter for error status",
|
||||
},
|
||||
[]string{"code", "method"},
|
||||
)
|
||||
|
||||
// Test different status codes
|
||||
statusCodes := []int{200, 404, 500, 301, 401}
|
||||
|
||||
for _, status := range statusCodes {
|
||||
t.Run(fmt.Sprintf("status_%d", status), func(t *testing.T) {
|
||||
handler := instrumentHandlerCounter(
|
||||
counter,
|
||||
&mockHTTPHandler{statusCode: status},
|
||||
)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
statusLabels := prometheus.Labels{"code": fmt.Sprintf("%d", status), "method": "GET"}
|
||||
initialValue := getCounterValue(counter, statusLabels)
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
finalValue := getCounterValue(counter, statusLabels)
|
||||
if finalValue != initialValue+1 {
|
||||
t.Errorf("Status %d: Expected counter increment", status)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func getMetricValue(gauge prometheus.Gauge) float64 {
|
||||
metric := &dto.Metric{}
|
||||
gauge.Write(metric)
|
||||
return metric.GetGauge().GetValue()
|
||||
}
|
||||
|
||||
func getCounterValue(counter *prometheus.CounterVec, labels prometheus.Labels) float64 {
|
||||
metric, err := counter.GetMetricWith(labels)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
pb := &dto.Metric{}
|
||||
metric.Write(pb)
|
||||
return pb.GetCounter().GetValue()
|
||||
}
|
||||
|
||||
type mockHTTPHandler struct {
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (m *mockHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(m.statusCode)
|
||||
}
|
||||
|
||||
func TestMetrics_Memory_Usage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory test in short mode")
|
||||
}
|
||||
|
||||
// Initialize metrics
|
||||
initAdminMetrics()
|
||||
|
||||
// Create many different label combinations
|
||||
const numLabels = 1000
|
||||
|
||||
for i := 0; i < numLabels; i++ {
|
||||
labels := prometheus.Labels{
|
||||
"handler": fmt.Sprintf("handler_%d", i%10),
|
||||
"path": fmt.Sprintf("/path_%d", i),
|
||||
"method": []string{"GET", "POST", "PUT", "DELETE"}[i%4],
|
||||
"code": []string{"200", "404", "500"}[i%3],
|
||||
}
|
||||
|
||||
adminMetrics.requestCount.With(labels).Inc()
|
||||
|
||||
// Also increment error counter occasionally
|
||||
if i%10 == 0 {
|
||||
errorLabels := prometheus.Labels{
|
||||
"handler": labels["handler"],
|
||||
"path": labels["path"],
|
||||
"method": labels["method"],
|
||||
}
|
||||
adminMetrics.requestErrors.With(errorLabels).Inc()
|
||||
}
|
||||
}
|
||||
|
||||
// Test passes if we don't run out of memory or panic
|
||||
}
|
||||
|
||||
func BenchmarkGlobalMetrics_ConfigSuccess(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
globalMetrics.configSuccess.Set(float64(i % 2))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGlobalMetrics_ConfigSuccessTime(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
globalMetrics.configSuccessTime.SetToCurrentTime()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAdminMetrics_RequestCount_WithLabels(b *testing.B) {
|
||||
initAdminMetrics()
|
||||
|
||||
labels := prometheus.Labels{
|
||||
"handler": "benchmark",
|
||||
"path": "/benchmark",
|
||||
"method": "GET",
|
||||
"code": "200",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
adminMetrics.requestCount.With(labels).Inc()
|
||||
}
|
||||
}
|
||||
|
|
@ -85,8 +85,11 @@ func (e HandlerError) Unwrap() error { return e.Err }
|
|||
// randString returns a string of n random characters.
|
||||
// It is not even remotely secure OR a proper distribution.
|
||||
// But it's good enough for some things. It excludes certain
|
||||
// confusing characters like I, l, 1, 0, O, etc. If sameCase
|
||||
// is true, then uppercase letters are excluded.
|
||||
// confusing characters like I, l, 1, 0, O. If sameCase
|
||||
// is true, then uppercase letters are excluded as well as
|
||||
// the characters l and o. If sameCase is false, both uppercase
|
||||
// and lowercase letters are used, and the characters I, l, 1, 0, O
|
||||
// are excluded.
|
||||
func randString(n int, sameCase bool) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
|
|
|
|||
279
modules/caddyhttp/errors_utils_test.go
Normal file
279
modules/caddyhttp/errors_utils_test.go
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddyhttp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func TestRandString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
sameCase bool
|
||||
wantLen int
|
||||
checkCase func(string) bool
|
||||
}{
|
||||
{
|
||||
name: "zero length",
|
||||
length: 0,
|
||||
sameCase: false,
|
||||
wantLen: 0,
|
||||
checkCase: func(s string) bool {
|
||||
return s == ""
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "negative length",
|
||||
length: -5,
|
||||
sameCase: false,
|
||||
wantLen: 0,
|
||||
checkCase: func(s string) bool {
|
||||
return s == ""
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single character mixed case",
|
||||
length: 1,
|
||||
sameCase: false,
|
||||
wantLen: 1,
|
||||
checkCase: func(s string) bool {
|
||||
// Should be alphanumeric
|
||||
return len(s) == 1 && (unicode.IsLetter(rune(s[0])) || unicode.IsDigit(rune(s[0])))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single character same case",
|
||||
length: 1,
|
||||
sameCase: true,
|
||||
wantLen: 1,
|
||||
checkCase: func(s string) bool {
|
||||
// Should be lowercase or digit
|
||||
return len(s) == 1 && (unicode.IsLower(rune(s[0])) || unicode.IsDigit(rune(s[0])))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "short string mixed case",
|
||||
length: 5,
|
||||
sameCase: false,
|
||||
wantLen: 5,
|
||||
checkCase: func(s string) bool {
|
||||
// All characters should be alphanumeric
|
||||
for _, c := range s {
|
||||
if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "short string same case",
|
||||
length: 5,
|
||||
sameCase: true,
|
||||
wantLen: 5,
|
||||
checkCase: func(s string) bool {
|
||||
// All characters should be lowercase or digits
|
||||
for _, c := range s {
|
||||
if unicode.IsUpper(c) {
|
||||
return false
|
||||
}
|
||||
if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "medium string mixed case",
|
||||
length: 20,
|
||||
sameCase: false,
|
||||
wantLen: 20,
|
||||
checkCase: func(s string) bool {
|
||||
for _, c := range s {
|
||||
if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "long string same case",
|
||||
length: 100,
|
||||
sameCase: true,
|
||||
wantLen: 100,
|
||||
checkCase: func(s string) bool {
|
||||
for _, c := range s {
|
||||
if unicode.IsUpper(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := randString(tt.length, tt.sameCase)
|
||||
|
||||
// Check length
|
||||
if len(result) != tt.wantLen {
|
||||
t.Errorf("randString(%d, %v) length = %d, want %d",
|
||||
tt.length, tt.sameCase, len(result), tt.wantLen)
|
||||
}
|
||||
|
||||
// Check case requirements
|
||||
if !tt.checkCase(result) {
|
||||
t.Errorf("randString(%d, %v) = %q failed case check",
|
||||
tt.length, tt.sameCase, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRandString_NoConfusingChars ensures that confusing characters
|
||||
// like I, l, 1, 0, O are excluded from the generated strings
|
||||
func TestRandString_NoConfusingChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sameCase bool
|
||||
excluded []rune
|
||||
}{
|
||||
{
|
||||
name: "mixed case excludes I,l,1,0,O",
|
||||
sameCase: false,
|
||||
excluded: []rune{'I', 'l', '1', '0', 'O'},
|
||||
},
|
||||
{
|
||||
name: "same case excludes l,0",
|
||||
sameCase: true,
|
||||
excluded: []rune{'l', 'o'},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Generate multiple strings to increase confidence
|
||||
for i := 0; i < 100; i++ {
|
||||
result := randString(50, tt.sameCase)
|
||||
|
||||
for _, char := range tt.excluded {
|
||||
if strings.ContainsRune(result, char) {
|
||||
t.Errorf("randString(50, %v) contains excluded character %q in %q",
|
||||
tt.sameCase, char, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRandString_Uniqueness verifies that consecutive calls produce
|
||||
// different strings (with high probability)
|
||||
func TestRandString_Uniqueness(t *testing.T) {
|
||||
const iterations = 100
|
||||
const length = 16
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sameCase bool
|
||||
}{
|
||||
{"mixed case", false},
|
||||
{"same case", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
seen := make(map[string]bool)
|
||||
duplicates := 0
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
result := randString(length, tt.sameCase)
|
||||
if seen[result] {
|
||||
duplicates++
|
||||
}
|
||||
seen[result] = true
|
||||
}
|
||||
|
||||
// With a 16-character string from a large alphabet, duplicates should be extremely rare
|
||||
// Allow at most 1 duplicate in 100 iterations
|
||||
if duplicates > 1 {
|
||||
t.Errorf("randString(%d, %v) produced %d duplicates in %d iterations (expected ≤1)",
|
||||
length, tt.sameCase, duplicates, iterations)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRandString_CharacterDistribution checks that the generated strings
|
||||
// contain a reasonable mix of characters (not just one character)
|
||||
func TestRandString_CharacterDistribution(t *testing.T) {
|
||||
const length = 1000
|
||||
const minUniqueChars = 15 // Should have at least 15 different characters in 1000 chars
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sameCase bool
|
||||
}{
|
||||
{"mixed case", false},
|
||||
{"same case", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := randString(length, tt.sameCase)
|
||||
|
||||
uniqueChars := make(map[rune]bool)
|
||||
for _, c := range result {
|
||||
uniqueChars[c] = true
|
||||
}
|
||||
|
||||
if len(uniqueChars) < minUniqueChars {
|
||||
t.Errorf("randString(%d, %v) produced only %d unique characters (expected ≥%d)",
|
||||
length, tt.sameCase, len(uniqueChars), minUniqueChars)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRandString measures the performance of random string generation
|
||||
func BenchmarkRandString(b *testing.B) {
|
||||
benchmarks := []struct {
|
||||
name string
|
||||
length int
|
||||
sameCase bool
|
||||
}{
|
||||
{"short_mixed", 8, false},
|
||||
{"short_same", 8, true},
|
||||
{"medium_mixed", 32, false},
|
||||
{"medium_same", 32, true},
|
||||
{"long_mixed", 128, false},
|
||||
{"long_same", 128, true},
|
||||
}
|
||||
|
||||
for _, bm := range benchmarks {
|
||||
b.Run(bm.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = randString(bm.length, bm.sameCase)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
194
modules/caddyhttp/rewrite/rewrite_utils_test.go
Normal file
194
modules/caddyhttp/rewrite/rewrite_utils_test.go
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package rewrite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReverse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple string",
|
||||
input: "hello",
|
||||
expected: "olleh",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single character",
|
||||
input: "a",
|
||||
expected: "a",
|
||||
},
|
||||
{
|
||||
name: "two characters",
|
||||
input: "ab",
|
||||
expected: "ba",
|
||||
},
|
||||
{
|
||||
name: "palindrome",
|
||||
input: "racecar",
|
||||
expected: "racecar",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
input: "hello world",
|
||||
expected: "dlrow olleh",
|
||||
},
|
||||
{
|
||||
name: "with numbers",
|
||||
input: "abc123",
|
||||
expected: "321cba",
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
input: "hello世界",
|
||||
expected: "界世olleh",
|
||||
},
|
||||
{
|
||||
name: "emoji",
|
||||
input: "🎉🎊🎈",
|
||||
expected: "🎈🎊🎉",
|
||||
},
|
||||
{
|
||||
name: "mixed unicode and ascii",
|
||||
input: "café☕",
|
||||
expected: "☕éfac",
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
input: "a!b@c#d$",
|
||||
expected: "$d#c@b!a",
|
||||
},
|
||||
{
|
||||
name: "path-like string",
|
||||
input: "/path/to/file",
|
||||
expected: "elif/ot/htap/",
|
||||
},
|
||||
{
|
||||
name: "url-like string",
|
||||
input: "https://example.com",
|
||||
expected: "moc.elpmaxe//:sptth",
|
||||
},
|
||||
{
|
||||
name: "long string",
|
||||
input: "The quick brown fox jumps over the lazy dog",
|
||||
expected: "god yzal eht revo spmuj xof nworb kciuq ehT",
|
||||
},
|
||||
{
|
||||
name: "newlines",
|
||||
input: "line1\nline2\nline3",
|
||||
expected: "3enil\n2enil\n1enil",
|
||||
},
|
||||
{
|
||||
name: "tabs",
|
||||
input: "a\tb\tc",
|
||||
expected: "c\tb\ta",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reverse(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("reverse(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
|
||||
// Test that reversing twice gives the original string
|
||||
if tt.input != "" {
|
||||
doubleReverse := reverse(reverse(tt.input))
|
||||
if doubleReverse != tt.input {
|
||||
t.Errorf("reverse(reverse(%q)) = %q; want %q", tt.input, doubleReverse, tt.input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverse_LengthPreservation(t *testing.T) {
|
||||
// Test that reverse preserves string length
|
||||
testStrings := []string{
|
||||
"",
|
||||
"a",
|
||||
"ab",
|
||||
"abc",
|
||||
"hello world",
|
||||
"🎉🎊🎈",
|
||||
"café☕",
|
||||
"The quick brown fox jumps over the lazy dog",
|
||||
}
|
||||
|
||||
for _, s := range testStrings {
|
||||
reversed := reverse(s)
|
||||
if len([]rune(s)) != len([]rune(reversed)) {
|
||||
t.Errorf("reverse(%q) changed length: original %d, reversed %d", s, len([]rune(s)), len([]rune(reversed)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkReverse benchmarks the reverse function
|
||||
func BenchmarkReverse(b *testing.B) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"short", "hello"},
|
||||
{"medium", "The quick brown fox jumps over the lazy dog"},
|
||||
{"long", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."},
|
||||
{"unicode", "hello世界🎉"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
reverse(tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverse_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{"null byte", "\x00"},
|
||||
{"multiple null bytes", "\x00\x00\x00"},
|
||||
{"control characters", "\t\n\r"},
|
||||
{"high unicode", "𝕳𝖊𝖑𝖑𝖔"},
|
||||
{"zero-width characters", "a\u200Bb\u200Cc"},
|
||||
{"combining characters", "é"}, // e + combining acute
|
||||
{"rtl text", "مرحبا"},
|
||||
{"mixed rtl/ltr", "Hello مرحبا World"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reverse(tt.input)
|
||||
// Just ensure it doesn't panic and returns something
|
||||
if result == "" && tt.input != "" {
|
||||
t.Errorf("reverse(%q) returned empty string", tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
963
network_test.go
Normal file
963
network_test.go
Normal file
|
|
@ -0,0 +1,963 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNetworkAddress_String_Consistency(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
}{
|
||||
{
|
||||
name: "basic tcp",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "localhost", StartPort: 8080, EndPort: 8080},
|
||||
},
|
||||
{
|
||||
name: "tcp with port range",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "localhost", StartPort: 8080, EndPort: 8090},
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
addr: NetworkAddress{Network: "unix", Host: "/tmp/socket"},
|
||||
},
|
||||
{
|
||||
name: "udp",
|
||||
addr: NetworkAddress{Network: "udp", Host: "0.0.0.0", StartPort: 53, EndPort: 53},
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "::1", StartPort: 80, EndPort: 80},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
str := test.addr.String()
|
||||
|
||||
// Parse the string back
|
||||
parsed, err := ParseNetworkAddress(str)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse string representation: %v", err)
|
||||
}
|
||||
|
||||
// Should be equivalent to original
|
||||
if parsed.Network != test.addr.Network {
|
||||
t.Errorf("Network mismatch: expected %s, got %s", test.addr.Network, parsed.Network)
|
||||
}
|
||||
if parsed.Host != test.addr.Host {
|
||||
t.Errorf("Host mismatch: expected %s, got %s", test.addr.Host, parsed.Host)
|
||||
}
|
||||
if parsed.StartPort != test.addr.StartPort {
|
||||
t.Errorf("StartPort mismatch: expected %d, got %d", test.addr.StartPort, parsed.StartPort)
|
||||
}
|
||||
if parsed.EndPort != test.addr.EndPort {
|
||||
t.Errorf("EndPort mismatch: expected %d, got %d", test.addr.EndPort, parsed.EndPort)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_PortRangeSize_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
expected uint
|
||||
}{
|
||||
{
|
||||
name: "single port",
|
||||
addr: NetworkAddress{StartPort: 80, EndPort: 80},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "invalid range (end < start)",
|
||||
addr: NetworkAddress{StartPort: 8080, EndPort: 8070},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "zero ports",
|
||||
addr: NetworkAddress{StartPort: 0, EndPort: 0},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "maximum range",
|
||||
addr: NetworkAddress{StartPort: 1, EndPort: 65535},
|
||||
expected: 65535,
|
||||
},
|
||||
{
|
||||
name: "large range",
|
||||
addr: NetworkAddress{StartPort: 8000, EndPort: 9000},
|
||||
expected: 1001,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
size := test.addr.PortRangeSize()
|
||||
if size != test.expected {
|
||||
t.Errorf("Expected %d, got %d", test.expected, size)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_At_Validation(t *testing.T) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8080,
|
||||
EndPort: 8090,
|
||||
}
|
||||
|
||||
// Test valid offsets
|
||||
for offset := uint(0); offset <= 10; offset++ {
|
||||
result := addr.At(offset)
|
||||
expectedPort := 8080 + offset
|
||||
|
||||
if result.StartPort != expectedPort || result.EndPort != expectedPort {
|
||||
t.Errorf("Offset %d: expected port %d, got %d-%d",
|
||||
offset, expectedPort, result.StartPort, result.EndPort)
|
||||
}
|
||||
|
||||
if result.Network != addr.Network || result.Host != addr.Host {
|
||||
t.Errorf("Offset %d: network/host should be preserved", offset)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_Expand_LargeRange(t *testing.T) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8000,
|
||||
EndPort: 8010,
|
||||
}
|
||||
|
||||
expanded := addr.Expand()
|
||||
expectedSize := 11 // 8000 to 8010 inclusive
|
||||
|
||||
if len(expanded) != expectedSize {
|
||||
t.Errorf("Expected %d addresses, got %d", expectedSize, len(expanded))
|
||||
}
|
||||
|
||||
// Verify each address
|
||||
for i, expandedAddr := range expanded {
|
||||
expectedPort := uint(8000 + i)
|
||||
if expandedAddr.StartPort != expectedPort || expandedAddr.EndPort != expectedPort {
|
||||
t.Errorf("Address %d: expected port %d, got %d-%d",
|
||||
i, expectedPort, expandedAddr.StartPort, expandedAddr.EndPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_IsLoopback_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "unix socket",
|
||||
addr: NetworkAddress{Network: "unix", Host: "/tmp/socket"},
|
||||
expected: true, // Unix sockets are always considered loopback
|
||||
},
|
||||
{
|
||||
name: "fd network",
|
||||
addr: NetworkAddress{Network: "fd", Host: "3"},
|
||||
expected: true, // fd networks are always considered loopback
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "localhost"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "127.0.0.1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "::1",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "::1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "127.0.0.2",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "127.0.0.2"},
|
||||
expected: true, // Part of 127.0.0.0/8 loopback range
|
||||
},
|
||||
{
|
||||
name: "192.168.1.1",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "192.168.1.1"},
|
||||
expected: false, // Private but not loopback
|
||||
},
|
||||
{
|
||||
name: "invalid ip",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "invalid-ip"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty host",
|
||||
addr: NetworkAddress{Network: "tcp", Host: ""},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.addr.isLoopback()
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_IsWildcard_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "empty host",
|
||||
addr: NetworkAddress{Network: "tcp", Host: ""},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ipv4 any",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "0.0.0.0"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 any",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "::"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "localhost"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "specific ip",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "192.168.1.1"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "invalid ip",
|
||||
addr: NetworkAddress{Network: "tcp", Host: "invalid"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.addr.isWildcardInterface()
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected %v, got %v", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitNetworkAddress_IPv6_EdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectNetwork string
|
||||
expectHost string
|
||||
expectPort string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "ipv6 with port",
|
||||
input: "[::1]:8080",
|
||||
expectHost: "::1",
|
||||
expectPort: "8080",
|
||||
},
|
||||
{
|
||||
name: "ipv6 without port",
|
||||
input: "[::1]",
|
||||
expectHost: "::1",
|
||||
},
|
||||
{
|
||||
name: "ipv6 without brackets or port",
|
||||
input: "::1",
|
||||
expectHost: "::1",
|
||||
},
|
||||
{
|
||||
name: "ipv6 loopback",
|
||||
input: "[::1]:443",
|
||||
expectHost: "::1",
|
||||
expectPort: "443",
|
||||
},
|
||||
{
|
||||
name: "ipv6 any address",
|
||||
input: "[::]:80",
|
||||
expectHost: "::",
|
||||
expectPort: "80",
|
||||
},
|
||||
{
|
||||
name: "ipv6 with network prefix",
|
||||
input: "tcp6/[::1]:8080",
|
||||
expectNetwork: "tcp6",
|
||||
expectHost: "::1",
|
||||
expectPort: "8080",
|
||||
},
|
||||
{
|
||||
name: "malformed ipv6",
|
||||
input: "[::1:8080", // Missing closing bracket
|
||||
expectHost: "::1:8080",
|
||||
},
|
||||
{
|
||||
name: "ipv6 with zone",
|
||||
input: "[fe80::1%eth0]:8080",
|
||||
expectHost: "fe80::1%eth0",
|
||||
expectPort: "8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
network, host, port, err := SplitNetworkAddress(test.input)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if network != test.expectNetwork {
|
||||
t.Errorf("Network: expected '%s', got '%s'", test.expectNetwork, network)
|
||||
}
|
||||
if host != test.expectHost {
|
||||
t.Errorf("Host: expected '%s', got '%s'", test.expectHost, host)
|
||||
}
|
||||
if port != test.expectPort {
|
||||
t.Errorf("Port: expected '%s', got '%s'", test.expectPort, port)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNetworkAddress_PortRange_Validation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid range",
|
||||
input: "localhost:8080-8090",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "inverted range",
|
||||
input: "localhost:8090-8080",
|
||||
expectErr: true,
|
||||
errMsg: "end port must not be less than start port",
|
||||
},
|
||||
{
|
||||
name: "too large range",
|
||||
input: "localhost:0-65535",
|
||||
expectErr: true,
|
||||
errMsg: "port range exceeds 65535 ports",
|
||||
},
|
||||
{
|
||||
name: "invalid start port",
|
||||
input: "localhost:abc-8080",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid end port",
|
||||
input: "localhost:8080-xyz",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "port too large",
|
||||
input: "localhost:99999",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "negative port",
|
||||
input: "localhost:-80",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
_, err := ParseNetworkAddress(test.input)
|
||||
|
||||
if test.expectErr && err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
if !test.expectErr && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if test.expectErr && test.errMsg != "" && err != nil {
|
||||
if !containsString(err.Error(), test.errMsg) {
|
||||
t.Errorf("Expected error containing '%s', got '%s'", test.errMsg, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_Listen_ContextCancellation(t *testing.T) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 0, // Let OS assign port
|
||||
EndPort: 0,
|
||||
}
|
||||
|
||||
// Create context that will be cancelled
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Start listening in a goroutine
|
||||
listenDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := addr.Listen(ctx, 0, net.ListenConfig{})
|
||||
listenDone <- err
|
||||
}()
|
||||
|
||||
// Cancel context immediately
|
||||
cancel()
|
||||
|
||||
// Should get context cancellation error quickly
|
||||
select {
|
||||
case err := <-listenDone:
|
||||
if err == nil {
|
||||
t.Error("Expected error due to context cancellation")
|
||||
}
|
||||
// Accept any error related to context cancellation
|
||||
// (could be context.Canceled or DNS lookup error due to cancellation)
|
||||
case <-time.After(time.Second):
|
||||
t.Error("Listen operation did not respect context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_ListenAll_PartialFailure(t *testing.T) {
|
||||
// Create an address range where some ports might fail to bind
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 0, // OS-assigned port
|
||||
EndPort: 2, // Try to bind 3 ports starting from OS-assigned
|
||||
}
|
||||
|
||||
// This test might be flaky depending on available ports,
|
||||
// but tests the error handling logic
|
||||
ctx := context.Background()
|
||||
|
||||
listeners, err := addr.ListenAll(ctx, net.ListenConfig{})
|
||||
|
||||
// Either all succeed or all fail (due to cleanup on partial failure)
|
||||
if err != nil {
|
||||
// If there's an error, no listeners should be returned
|
||||
if len(listeners) != 0 {
|
||||
t.Errorf("Expected no listeners on error, got %d", len(listeners))
|
||||
}
|
||||
} else {
|
||||
// If successful, should have listeners for all ports in range
|
||||
expectedCount := int(addr.PortRangeSize())
|
||||
if len(listeners) != expectedCount {
|
||||
t.Errorf("Expected %d listeners, got %d", expectedCount, len(listeners))
|
||||
}
|
||||
|
||||
// Clean up listeners
|
||||
for _, ln := range listeners {
|
||||
if closer, ok := ln.(interface{ Close() error }); ok {
|
||||
closer.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinNetworkAddress_SpecialCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
host string
|
||||
port string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty everything",
|
||||
network: "",
|
||||
host: "",
|
||||
port: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "network only",
|
||||
network: "tcp",
|
||||
host: "",
|
||||
port: "",
|
||||
expected: "tcp/",
|
||||
},
|
||||
{
|
||||
name: "host only",
|
||||
network: "",
|
||||
host: "localhost",
|
||||
port: "",
|
||||
expected: "localhost",
|
||||
},
|
||||
{
|
||||
name: "port only",
|
||||
network: "",
|
||||
host: "",
|
||||
port: "8080",
|
||||
expected: ":8080",
|
||||
},
|
||||
{
|
||||
name: "unix socket with port (port ignored)",
|
||||
network: "unix",
|
||||
host: "/tmp/socket",
|
||||
port: "8080",
|
||||
expected: "unix//tmp/socket",
|
||||
},
|
||||
{
|
||||
name: "fd network with port (port ignored)",
|
||||
network: "fd",
|
||||
host: "3",
|
||||
port: "8080",
|
||||
expected: "fd/3",
|
||||
},
|
||||
{
|
||||
name: "ipv6 host with port",
|
||||
network: "tcp",
|
||||
host: "::1",
|
||||
port: "8080",
|
||||
expected: "tcp/[::1]:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := JoinNetworkAddress(test.network, test.host, test.port)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsUnixNetwork_IsFdNetwork(t *testing.T) {
|
||||
tests := []struct {
|
||||
network string
|
||||
isUnix bool
|
||||
isFd bool
|
||||
}{
|
||||
{"unix", true, false},
|
||||
{"unixgram", true, false},
|
||||
{"unixpacket", true, false},
|
||||
{"fd", false, true},
|
||||
{"fdgram", false, true},
|
||||
{"tcp", false, false},
|
||||
{"udp", false, false},
|
||||
{"", false, false},
|
||||
{"unix-like", true, false},
|
||||
{"fd-like", false, true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.network, func(t *testing.T) {
|
||||
if IsUnixNetwork(test.network) != test.isUnix {
|
||||
t.Errorf("IsUnixNetwork('%s'): expected %v, got %v",
|
||||
test.network, test.isUnix, IsUnixNetwork(test.network))
|
||||
}
|
||||
if IsFdNetwork(test.network) != test.isFd {
|
||||
t.Errorf("IsFdNetwork('%s'): expected %v, got %v",
|
||||
test.network, test.isFd, IsFdNetwork(test.network))
|
||||
}
|
||||
|
||||
// Test NetworkAddress methods too
|
||||
addr := NetworkAddress{Network: test.network}
|
||||
if addr.IsUnixNetwork() != test.isUnix {
|
||||
t.Errorf("NetworkAddress.IsUnixNetwork(): expected %v, got %v",
|
||||
test.isUnix, addr.IsUnixNetwork())
|
||||
}
|
||||
if addr.IsFdNetwork() != test.isFd {
|
||||
t.Errorf("NetworkAddress.IsFdNetwork(): expected %v, got %v",
|
||||
test.isFd, addr.IsFdNetwork())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterNetwork_Validation(t *testing.T) {
|
||||
// Save original state
|
||||
originalNetworkTypes := make(map[string]ListenerFunc)
|
||||
for k, v := range networkTypes {
|
||||
originalNetworkTypes[k] = v
|
||||
}
|
||||
defer func() {
|
||||
// Restore original state
|
||||
networkTypes = originalNetworkTypes
|
||||
}()
|
||||
|
||||
mockListener := func(ctx context.Context, network, host, portRange string, portOffset uint, cfg net.ListenConfig) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Test reserved network types that should panic
|
||||
reservedTypes := []string{
|
||||
"tcp", "tcp4", "tcp6",
|
||||
"udp", "udp4", "udp6",
|
||||
"unix", "unixpacket", "unixgram",
|
||||
"ip:1", "ip4:1", "ip6:1",
|
||||
"fd", "fdgram",
|
||||
}
|
||||
|
||||
for _, networkType := range reservedTypes {
|
||||
t.Run("reserved_"+networkType, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("Expected panic for reserved network type: %s", networkType)
|
||||
}
|
||||
}()
|
||||
RegisterNetwork(networkType, mockListener)
|
||||
})
|
||||
}
|
||||
|
||||
// Test valid registration
|
||||
t.Run("valid_registration", func(t *testing.T) {
|
||||
customNetwork := "custom-network"
|
||||
RegisterNetwork(customNetwork, mockListener)
|
||||
|
||||
if _, exists := networkTypes[customNetwork]; !exists {
|
||||
t.Error("Custom network should be registered")
|
||||
}
|
||||
})
|
||||
|
||||
// Test duplicate registration should panic
|
||||
t.Run("duplicate_registration", func(t *testing.T) {
|
||||
customNetwork := "another-custom"
|
||||
RegisterNetwork(customNetwork, mockListener)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("Expected panic for duplicate registration")
|
||||
}
|
||||
}()
|
||||
RegisterNetwork(customNetwork, mockListener)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListenerUsage_EdgeCases(t *testing.T) {
|
||||
// Test ListenerUsage function with various inputs
|
||||
tests := []struct {
|
||||
name string
|
||||
network string
|
||||
addr string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "non-existent listener",
|
||||
network: "tcp",
|
||||
addr: "localhost:9999",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "empty network and address",
|
||||
network: "",
|
||||
addr: "",
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
network: "unix",
|
||||
addr: "/tmp/non-existent.sock",
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
usage := ListenerUsage(test.network, test.addr)
|
||||
if usage != test.expected {
|
||||
t.Errorf("Expected usage %d, got %d", test.expected, usage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_Port_Formatting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single port",
|
||||
addr: NetworkAddress{StartPort: 80, EndPort: 80},
|
||||
expected: "80",
|
||||
},
|
||||
{
|
||||
name: "port range",
|
||||
addr: NetworkAddress{StartPort: 8080, EndPort: 8090},
|
||||
expected: "8080-8090",
|
||||
},
|
||||
{
|
||||
name: "zero ports",
|
||||
addr: NetworkAddress{StartPort: 0, EndPort: 0},
|
||||
expected: "0",
|
||||
},
|
||||
{
|
||||
name: "large ports",
|
||||
addr: NetworkAddress{StartPort: 65534, EndPort: 65535},
|
||||
expected: "65534-65535",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.addr.port()
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_JoinHostPort_SpecialNetworks(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr NetworkAddress
|
||||
offset uint
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "unix socket ignores offset",
|
||||
addr: NetworkAddress{
|
||||
Network: "unix",
|
||||
Host: "/tmp/socket",
|
||||
},
|
||||
offset: 100,
|
||||
expected: "/tmp/socket",
|
||||
},
|
||||
{
|
||||
name: "fd network ignores offset",
|
||||
addr: NetworkAddress{
|
||||
Network: "fd",
|
||||
Host: "3",
|
||||
},
|
||||
offset: 50,
|
||||
expected: "3",
|
||||
},
|
||||
{
|
||||
name: "tcp with offset",
|
||||
addr: NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8000,
|
||||
},
|
||||
offset: 10,
|
||||
expected: "localhost:8010",
|
||||
},
|
||||
{
|
||||
name: "ipv6 with offset",
|
||||
addr: NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "::1",
|
||||
StartPort: 8000,
|
||||
},
|
||||
offset: 5,
|
||||
expected: "[::1]:8005",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.addr.JoinHostPort(test.offset)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for string containment check
|
||||
func containsString(haystack, needle string) bool {
|
||||
return len(haystack) >= len(needle) &&
|
||||
(needle == "" || haystack == needle ||
|
||||
strings.Contains(haystack, needle))
|
||||
}
|
||||
|
||||
func TestListenerKey_Generation(t *testing.T) {
|
||||
tests := []struct {
|
||||
network string
|
||||
addr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
network: "tcp",
|
||||
addr: "localhost:8080",
|
||||
expected: "tcp/localhost:8080",
|
||||
},
|
||||
{
|
||||
network: "unix",
|
||||
addr: "/tmp/socket",
|
||||
expected: "unix//tmp/socket",
|
||||
},
|
||||
{
|
||||
network: "",
|
||||
addr: "localhost:8080",
|
||||
expected: "/localhost:8080",
|
||||
},
|
||||
{
|
||||
network: "tcp",
|
||||
addr: "",
|
||||
expected: "tcp/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s_%s", test.network, test.addr), func(t *testing.T) {
|
||||
result := listenerKey(test.network, test.addr)
|
||||
if result != test.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", test.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkAddress_ConcurrentAccess(t *testing.T) {
|
||||
// Test that NetworkAddress methods are safe for concurrent read access
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8080,
|
||||
EndPort: 8090,
|
||||
}
|
||||
|
||||
const numGoroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Call various methods concurrently
|
||||
_ = addr.String()
|
||||
_ = addr.PortRangeSize()
|
||||
_ = addr.IsUnixNetwork()
|
||||
_ = addr.IsFdNetwork()
|
||||
_ = addr.isLoopback()
|
||||
_ = addr.isWildcardInterface()
|
||||
_ = addr.port()
|
||||
_ = addr.JoinHostPort(uint(id % 10))
|
||||
_ = addr.At(uint(id % 11))
|
||||
|
||||
// Expand creates new slice, should be safe
|
||||
expanded := addr.Expand()
|
||||
if len(expanded) == 0 {
|
||||
t.Errorf("Goroutine %d: Expected non-empty expansion", id)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestNetworkAddress_IPv6_Zone_Handling(t *testing.T) {
|
||||
// Test IPv6 addresses with zone identifiers
|
||||
input := "tcp/[fe80::1%eth0]:8080"
|
||||
|
||||
addr, err := ParseNetworkAddress(input)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse IPv6 with zone: %v", err)
|
||||
}
|
||||
|
||||
if addr.Network != "tcp" {
|
||||
t.Errorf("Expected network 'tcp', got '%s'", addr.Network)
|
||||
}
|
||||
if addr.Host != "fe80::1%eth0" {
|
||||
t.Errorf("Expected host 'fe80::1%%eth0', got '%s'", addr.Host)
|
||||
}
|
||||
if addr.StartPort != 8080 {
|
||||
t.Errorf("Expected port 8080, got %d", addr.StartPort)
|
||||
}
|
||||
|
||||
// Test string representation round-trip
|
||||
str := addr.String()
|
||||
parsed, err := ParseNetworkAddress(str)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse string representation: %v", err)
|
||||
}
|
||||
|
||||
if parsed.Host != addr.Host {
|
||||
t.Errorf("Round-trip failed: expected host '%s', got '%s'", addr.Host, parsed.Host)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseNetworkAddress(b *testing.B) {
|
||||
inputs := []string{
|
||||
"localhost:8080",
|
||||
"tcp/localhost:8080-8090",
|
||||
"unix//tmp/socket",
|
||||
"[::1]:443",
|
||||
"udp/:53",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
input := inputs[i%len(inputs)]
|
||||
ParseNetworkAddress(input)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNetworkAddress_String(b *testing.B) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8080,
|
||||
EndPort: 8090,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
addr.String()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNetworkAddress_Expand(b *testing.B) {
|
||||
addr := NetworkAddress{
|
||||
Network: "tcp",
|
||||
Host: "localhost",
|
||||
StartPort: 8000,
|
||||
EndPort: 8100, // 101 addresses
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
addr.Expand()
|
||||
}
|
||||
}
|
||||
1113
storage_test.go
Normal file
1113
storage_test.go
Normal file
File diff suppressed because it is too large
Load diff
624
usagepool_test.go
Normal file
624
usagepool_test.go
Normal file
|
|
@ -0,0 +1,624 @@
|
|||
// Copyright 2015 Matthew Holt and The Caddy Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockDestructor struct {
|
||||
value string
|
||||
destroyed int32
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockDestructor) Destruct() error {
|
||||
atomic.StoreInt32(&m.destroyed, 1)
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (m *mockDestructor) IsDestroyed() bool {
|
||||
return atomic.LoadInt32(&m.destroyed) == 1
|
||||
}
|
||||
|
||||
func TestUsagePool_LoadOrNew_Basic(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
|
||||
// First load should construct new value
|
||||
val, loaded, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
return &mockDestructor{value: "test-value"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("Expected loaded to be false for new value")
|
||||
}
|
||||
if val.(*mockDestructor).value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got '%s'", val.(*mockDestructor).value)
|
||||
}
|
||||
|
||||
// Second load should return existing value
|
||||
val2, loaded2, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
t.Error("Constructor should not be called for existing value")
|
||||
return nil, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if !loaded2 {
|
||||
t.Error("Expected loaded to be true for existing value")
|
||||
}
|
||||
if val2.(*mockDestructor).value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got '%s'", val2.(*mockDestructor).value)
|
||||
}
|
||||
|
||||
// Check reference count
|
||||
refs, exists := pool.References(key)
|
||||
if !exists {
|
||||
t.Error("Key should exist in pool")
|
||||
}
|
||||
if refs != 2 {
|
||||
t.Errorf("Expected 2 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_LoadOrNew_ConstructorError(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
expectedErr := errors.New("constructor failed")
|
||||
|
||||
val, loaded, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
return nil, expectedErr
|
||||
})
|
||||
if err != expectedErr {
|
||||
t.Errorf("Expected constructor error, got: %v", err)
|
||||
}
|
||||
if loaded {
|
||||
t.Error("Expected loaded to be false for failed construction")
|
||||
}
|
||||
if val != nil {
|
||||
t.Error("Expected nil value for failed construction")
|
||||
}
|
||||
|
||||
// Key should not exist after constructor failure
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
t.Error("Key should not exist after constructor failure")
|
||||
}
|
||||
if refs != 0 {
|
||||
t.Errorf("Expected 0 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_LoadOrStore_Basic(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
mockVal := &mockDestructor{value: "stored-value"}
|
||||
|
||||
// First load/store should store new value
|
||||
val, loaded := pool.LoadOrStore(key, mockVal)
|
||||
if loaded {
|
||||
t.Error("Expected loaded to be false for new value")
|
||||
}
|
||||
if val != mockVal {
|
||||
t.Error("Expected stored value to be returned")
|
||||
}
|
||||
|
||||
// Second load/store should return existing value
|
||||
newMockVal := &mockDestructor{value: "new-value"}
|
||||
val2, loaded2 := pool.LoadOrStore(key, newMockVal)
|
||||
if !loaded2 {
|
||||
t.Error("Expected loaded to be true for existing value")
|
||||
}
|
||||
if val2 != mockVal {
|
||||
t.Error("Expected original stored value to be returned")
|
||||
}
|
||||
|
||||
// Check reference count
|
||||
refs, exists := pool.References(key)
|
||||
if !exists {
|
||||
t.Error("Key should exist in pool")
|
||||
}
|
||||
if refs != 2 {
|
||||
t.Errorf("Expected 2 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Delete_Basic(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
mockVal := &mockDestructor{value: "test-value"}
|
||||
|
||||
// Store value twice to get ref count of 2
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
|
||||
// First delete should decrement ref count
|
||||
deleted, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if deleted {
|
||||
t.Error("Expected deleted to be false when refs > 0")
|
||||
}
|
||||
if mockVal.IsDestroyed() {
|
||||
t.Error("Value should not be destroyed yet")
|
||||
}
|
||||
|
||||
// Second delete should destroy value
|
||||
deleted, err = pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if !deleted {
|
||||
t.Error("Expected deleted to be true when refs = 0")
|
||||
}
|
||||
if !mockVal.IsDestroyed() {
|
||||
t.Error("Value should be destroyed")
|
||||
}
|
||||
|
||||
// Key should not exist after deletion
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
t.Error("Key should not exist after deletion")
|
||||
}
|
||||
if refs != 0 {
|
||||
t.Errorf("Expected 0 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Delete_NonExistentKey(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
|
||||
deleted, err := pool.Delete("non-existent")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for non-existent key, got: %v", err)
|
||||
}
|
||||
if deleted {
|
||||
t.Error("Expected deleted to be false for non-existent key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Delete_PanicOnNegativeRefs(t *testing.T) {
|
||||
// This test demonstrates the panic condition by manipulating
|
||||
// the ref count directly to create an invalid state
|
||||
pool := NewUsagePool()
|
||||
key := "test-key"
|
||||
mockVal := &mockDestructor{value: "test-value"}
|
||||
|
||||
// Store the value to get it in the pool
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
|
||||
// Get the pool value to manipulate its refs directly
|
||||
pool.Lock()
|
||||
upv, exists := pool.pool[key]
|
||||
if !exists {
|
||||
pool.Unlock()
|
||||
t.Fatal("Value should exist in pool")
|
||||
}
|
||||
|
||||
// Manually set refs to 1 to test the panic condition
|
||||
atomic.StoreInt32(&upv.refs, 1)
|
||||
pool.Unlock()
|
||||
|
||||
// Now delete twice - the second delete should cause refs to go negative
|
||||
// First delete
|
||||
deleted1, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Fatalf("First delete failed: %v", err)
|
||||
}
|
||||
if !deleted1 {
|
||||
t.Error("First delete should have removed the value")
|
||||
}
|
||||
|
||||
// Second delete on the same key after it was removed should be safe
|
||||
deleted2, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Errorf("Second delete should not error: %v", err)
|
||||
}
|
||||
if deleted2 {
|
||||
t.Error("Second delete should return false for non-existent key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Range(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
|
||||
// Add multiple values
|
||||
values := map[string]string{
|
||||
"key1": "value1",
|
||||
"key2": "value2",
|
||||
"key3": "value3",
|
||||
}
|
||||
|
||||
for key, value := range values {
|
||||
pool.LoadOrStore(key, &mockDestructor{value: value})
|
||||
}
|
||||
|
||||
// Range through all values
|
||||
found := make(map[string]string)
|
||||
pool.Range(func(key, value any) bool {
|
||||
found[key.(string)] = value.(*mockDestructor).value
|
||||
return true
|
||||
})
|
||||
|
||||
if len(found) != len(values) {
|
||||
t.Errorf("Expected %d values, got %d", len(values), len(found))
|
||||
}
|
||||
|
||||
for key, expectedValue := range values {
|
||||
if actualValue, exists := found[key]; !exists || actualValue != expectedValue {
|
||||
t.Errorf("Key %s: expected '%s', got '%s'", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Range_EarlyReturn(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
|
||||
// Add multiple values
|
||||
for i := 0; i < 5; i++ {
|
||||
pool.LoadOrStore(i, &mockDestructor{value: "value"})
|
||||
}
|
||||
|
||||
// Range but return false after first iteration
|
||||
count := 0
|
||||
pool.Range(func(key, value any) bool {
|
||||
count++
|
||||
return false // Stop after first iteration
|
||||
})
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 iteration, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Concurrent_LoadOrNew(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "concurrent-key"
|
||||
constructorCalls := int32(0)
|
||||
|
||||
const numGoroutines = 100
|
||||
var wg sync.WaitGroup
|
||||
results := make([]any, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
val, _, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
atomic.AddInt32(&constructorCalls, 1)
|
||||
// Add small delay to increase chance of race conditions
|
||||
time.Sleep(time.Microsecond)
|
||||
return &mockDestructor{value: "concurrent-value"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: Unexpected error: %v", index, err)
|
||||
return
|
||||
}
|
||||
results[index] = val
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Constructor should only be called once
|
||||
if calls := atomic.LoadInt32(&constructorCalls); calls != 1 {
|
||||
t.Errorf("Expected constructor to be called once, was called %d times", calls)
|
||||
}
|
||||
|
||||
// All goroutines should get the same value
|
||||
firstVal := results[0]
|
||||
for i, val := range results {
|
||||
if val != firstVal {
|
||||
t.Errorf("Goroutine %d got different value than first goroutine", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Reference count should equal number of goroutines
|
||||
refs, exists := pool.References(key)
|
||||
if !exists {
|
||||
t.Error("Key should exist in pool")
|
||||
}
|
||||
if refs != numGoroutines {
|
||||
t.Errorf("Expected %d references, got %d", numGoroutines, refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Concurrent_Delete(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "concurrent-delete-key"
|
||||
mockVal := &mockDestructor{value: "test-value"}
|
||||
|
||||
const numRefs = 50
|
||||
|
||||
// Add multiple references
|
||||
for i := 0; i < numRefs; i++ {
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
deleteResults := make([]bool, numRefs)
|
||||
|
||||
// Delete concurrently
|
||||
for i := 0; i < numRefs; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
deleted, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Errorf("Goroutine %d: Unexpected error: %v", index, err)
|
||||
return
|
||||
}
|
||||
deleteResults[index] = deleted
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Exactly one delete should have returned true (when refs reached 0)
|
||||
deletedCount := 0
|
||||
for _, deleted := range deleteResults {
|
||||
if deleted {
|
||||
deletedCount++
|
||||
}
|
||||
}
|
||||
if deletedCount != 1 {
|
||||
t.Errorf("Expected exactly 1 delete to return true, got %d", deletedCount)
|
||||
}
|
||||
|
||||
// Value should be destroyed
|
||||
if !mockVal.IsDestroyed() {
|
||||
t.Error("Value should be destroyed after all references deleted")
|
||||
}
|
||||
|
||||
// Key should not exist
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
t.Error("Key should not exist after all references deleted")
|
||||
}
|
||||
if refs != 0 {
|
||||
t.Errorf("Expected 0 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_DestructorError(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "destructor-error-key"
|
||||
expectedErr := errors.New("destructor failed")
|
||||
mockVal := &mockDestructor{value: "test-value", err: expectedErr}
|
||||
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
|
||||
deleted, err := pool.Delete(key)
|
||||
if err != expectedErr {
|
||||
t.Errorf("Expected destructor error, got: %v", err)
|
||||
}
|
||||
if !deleted {
|
||||
t.Error("Expected deleted to be true even with destructor error")
|
||||
}
|
||||
if !mockVal.IsDestroyed() {
|
||||
t.Error("Destructor should have been called despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Mixed_Concurrent_Operations(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
keys := []string{"key1", "key2", "key3"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const opsPerKey = 10
|
||||
|
||||
// Test concurrent operations but with more controlled behavior
|
||||
for _, key := range keys {
|
||||
for i := 0; i < opsPerKey; i++ {
|
||||
wg.Add(2) // LoadOrStore and Delete
|
||||
|
||||
// LoadOrStore (safer than LoadOrNew for concurrency)
|
||||
go func(k string) {
|
||||
defer wg.Done()
|
||||
pool.LoadOrStore(k, &mockDestructor{value: k + "-value"})
|
||||
}(key)
|
||||
|
||||
// Delete (may fail if refs are 0, that's fine)
|
||||
go func(k string) {
|
||||
defer wg.Done()
|
||||
pool.Delete(k)
|
||||
}(key)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Test that the pool is in a consistent state
|
||||
for _, key := range keys {
|
||||
refs, exists := pool.References(key)
|
||||
if exists && refs < 0 {
|
||||
t.Errorf("Key %s has negative reference count: %d", key, refs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_Range_SkipsErrorValues(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
|
||||
// Add value that will succeed
|
||||
goodKey := "good-key"
|
||||
pool.LoadOrStore(goodKey, &mockDestructor{value: "good-value"})
|
||||
|
||||
// Try to add value that will fail construction
|
||||
badKey := "bad-key"
|
||||
pool.LoadOrNew(badKey, func() (Destructor, error) {
|
||||
return nil, errors.New("construction failed")
|
||||
})
|
||||
|
||||
// Range should only iterate good values
|
||||
count := 0
|
||||
pool.Range(func(key, value any) bool {
|
||||
count++
|
||||
if key.(string) != goodKey {
|
||||
t.Errorf("Expected only good key, got: %s", key.(string))
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("Expected 1 value in range, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_LoadOrStore_ErrorRecovery(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "error-recovery-key"
|
||||
|
||||
// First, create a value that fails construction
|
||||
_, _, err := pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
return nil, errors.New("construction failed")
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("Expected constructor error")
|
||||
}
|
||||
|
||||
// Now try LoadOrStore with a good value - should recover
|
||||
goodVal := &mockDestructor{value: "recovery-value"}
|
||||
val, loaded := pool.LoadOrStore(key, goodVal)
|
||||
if loaded {
|
||||
t.Error("Expected loaded to be false for error recovery")
|
||||
}
|
||||
if val != goodVal {
|
||||
t.Error("Expected recovery value to be returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_MemoryLeak_Prevention(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "memory-leak-test"
|
||||
|
||||
// Create many references
|
||||
const numRefs = 1000
|
||||
mockVal := &mockDestructor{value: "leak-test"}
|
||||
|
||||
for i := 0; i < numRefs; i++ {
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}
|
||||
|
||||
// Delete all references
|
||||
for i := 0; i < numRefs; i++ {
|
||||
deleted, err := pool.Delete(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete %d: Unexpected error: %v", i, err)
|
||||
}
|
||||
if i == numRefs-1 && !deleted {
|
||||
t.Error("Last delete should return true")
|
||||
} else if i < numRefs-1 && deleted {
|
||||
t.Errorf("Delete %d should return false", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify destructor was called
|
||||
if !mockVal.IsDestroyed() {
|
||||
t.Error("Value should be destroyed after all references deleted")
|
||||
}
|
||||
|
||||
// Verify no memory leak - key should be removed from map
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
t.Error("Key should not exist after complete deletion")
|
||||
}
|
||||
if refs != 0 {
|
||||
t.Errorf("Expected 0 references, got %d", refs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsagePool_RaceCondition_RefsCounter(t *testing.T) {
|
||||
pool := NewUsagePool()
|
||||
key := "race-test-key"
|
||||
mockVal := &mockDestructor{value: "race-value"}
|
||||
|
||||
const numOperations = 100
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Mix of increment and decrement operations
|
||||
for i := 0; i < numOperations; i++ {
|
||||
wg.Add(2)
|
||||
|
||||
// Increment (LoadOrStore)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}()
|
||||
|
||||
// Decrement (Delete) - may fail if refs are 0, that's ok
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pool.Delete(key)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Final reference count should be consistent
|
||||
refs, exists := pool.References(key)
|
||||
if exists {
|
||||
if refs < 0 {
|
||||
t.Errorf("Reference count should never be negative, got: %d", refs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUsagePool_LoadOrNew(b *testing.B) {
|
||||
pool := NewUsagePool()
|
||||
key := "bench-key"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.LoadOrNew(key, func() (Destructor, error) {
|
||||
return &mockDestructor{value: "bench-value"}, nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUsagePool_LoadOrStore(b *testing.B) {
|
||||
pool := NewUsagePool()
|
||||
key := "bench-key"
|
||||
mockVal := &mockDestructor{value: "bench-value"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUsagePool_Delete(b *testing.B) {
|
||||
pool := NewUsagePool()
|
||||
key := "bench-key"
|
||||
mockVal := &mockDestructor{value: "bench-value"}
|
||||
|
||||
// Pre-populate with many references
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.LoadOrStore(key, mockVal)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
pool.Delete(key)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue