diff --git a/src/testing/testing.go b/src/testing/testing.go index 466dd96981a..fc52f3c547e 100644 --- a/src/testing/testing.go +++ b/src/testing/testing.go @@ -667,6 +667,7 @@ var _ TB = (*B)(nil) type T struct { common isParallel bool + isEnvSet bool context *testContext // For running tests and subtests. } @@ -964,6 +965,29 @@ func (c *common) TempDir() string { return dir } +// Setenv calls os.Setenv(key, value) and uses Cleanup to +// restore the environment variable to its original value +// after the test. +// +// This cannot be used in parallel tests. +func (c *common) Setenv(key, value string) { + prevValue, ok := os.LookupEnv(key) + + if err := os.Setenv(key, value); err != nil { + c.Fatalf("cannot set environment variable: %v", err) + } + + if ok { + c.Cleanup(func() { + os.Setenv(key, prevValue) + }) + } else { + c.Cleanup(func() { + os.Unsetenv(key) + }) + } +} + // panicHanding is an argument to runCleanup. type panicHandling int @@ -1035,6 +1059,9 @@ func (t *T) Parallel() { if t.isParallel { panic("testing: t.Parallel called multiple times") } + if t.isEnvSet { + panic("testing: t.Parallel called after t.Setenv; cannot set environment variables in parallel tests") + } t.isParallel = true // We don't want to include the time we spend waiting for serial tests @@ -1068,6 +1095,21 @@ func (t *T) Parallel() { t.raceErrors += -race.Errors() } +// Setenv calls os.Setenv(key, value) and uses Cleanup to +// restore the environment variable to its original value +// after the test. +// +// This cannot be used in parallel tests. +func (t *T) Setenv(key, value string) { + if t.isParallel { + panic("testing: t.Setenv called after t.Parallel; cannot set environment variables in parallel tests") + } + + t.isEnvSet = true + + t.common.Setenv(key, value) +} + // InternalTest is an internal type but exported because it is cross-package; // it is part of the implementation of the "go test" command. type InternalTest struct { diff --git a/src/testing/testing_test.go b/src/testing/testing_test.go index 0f096980ca4..55a4df4739e 100644 --- a/src/testing/testing_test.go +++ b/src/testing/testing_test.go @@ -109,3 +109,86 @@ func testTempDir(t *testing.T) { t.Errorf("unexpected %d files in TempDir: %v", len(files), files) } } + +func TestSetenv(t *testing.T) { + tests := []struct { + name string + key string + initialValueExists bool + initialValue string + newValue string + }{ + { + name: "initial value exists", + key: "GO_TEST_KEY_1", + initialValueExists: true, + initialValue: "111", + newValue: "222", + }, + { + name: "initial value exists but empty", + key: "GO_TEST_KEY_2", + initialValueExists: true, + initialValue: "", + newValue: "222", + }, + { + name: "initial value is not exists", + key: "GO_TEST_KEY_3", + initialValueExists: false, + initialValue: "", + newValue: "222", + }, + } + + for _, test := range tests { + if test.initialValueExists { + if err := os.Setenv(test.key, test.initialValue); err != nil { + t.Fatalf("unable to set env: got %v", err) + } + } else { + os.Unsetenv(test.key) + } + + t.Run(test.name, func(t *testing.T) { + t.Setenv(test.key, test.newValue) + if os.Getenv(test.key) != test.newValue { + t.Fatalf("unexpected value after t.Setenv: got %s, want %s", os.Getenv(test.key), test.newValue) + } + }) + + got, exists := os.LookupEnv(test.key) + if got != test.initialValue { + t.Fatalf("unexpected value after t.Setenv cleanup: got %s, want %s", got, test.initialValue) + } + if exists != test.initialValueExists { + t.Fatalf("unexpected value after t.Setenv cleanup: got %t, want %t", exists, test.initialValueExists) + } + } +} + +func TestSetenvWithParallelAfterSetenv(t *testing.T) { + defer func() { + want := "testing: t.Parallel called after t.Setenv; cannot set environment variables in parallel tests" + if got := recover(); got != want { + t.Fatalf("expected panic; got %#v want %q", got, want) + } + }() + + t.Setenv("GO_TEST_KEY_1", "value") + + t.Parallel() +} + +func TestSetenvWithParallelBeforeSetenv(t *testing.T) { + defer func() { + want := "testing: t.Setenv called after t.Parallel; cannot set environment variables in parallel tests" + if got := recover(); got != want { + t.Fatalf("expected panic; got %#v want %q", got, want) + } + }() + + t.Parallel() + + t.Setenv("GO_TEST_KEY_1", "value") +}