mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
flag: panic if flag name begins with - or contains =
Fixes #41792
Change-Id: I9b4aae8a899e3c3ac9532d27932d275cfb1fab48
GitHub-Last-Rev: f06b1e1767
GitHub-Pull-Request: golang/go#42737
Reviewed-on: https://go-review.googlesource.com/c/go/+/271788
Reviewed-by: Emmanuel Odeke <emmanuel@orijtech.com>
Reviewed-by: Rob Pike <r@golang.org>
Run-TryBot: Emmanuel Odeke <emmanuel@orijtech.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Rob Pike <r@golang.org>
This commit is contained in:
parent
d0b79e3513
commit
5ce51ea741
2 changed files with 101 additions and 6 deletions
|
|
@ -857,17 +857,23 @@ func Func(name, usage string, fn func(string) error) {
|
||||||
// of strings by giving the slice the methods of Value; in particular, Set would
|
// of strings by giving the slice the methods of Value; in particular, Set would
|
||||||
// decompose the comma-separated string into the slice.
|
// decompose the comma-separated string into the slice.
|
||||||
func (f *FlagSet) Var(value Value, name string, usage string) {
|
func (f *FlagSet) Var(value Value, name string, usage string) {
|
||||||
|
// Flag must not begin "-" or contain "=".
|
||||||
|
if strings.HasPrefix(name, "-") {
|
||||||
|
panic(f.sprintf("flag %q begins with -", name))
|
||||||
|
} else if strings.Contains(name, "=") {
|
||||||
|
panic(f.sprintf("flag %q contains =", name))
|
||||||
|
}
|
||||||
|
|
||||||
// Remember the default value as a string; it won't change.
|
// Remember the default value as a string; it won't change.
|
||||||
flag := &Flag{name, usage, value, value.String()}
|
flag := &Flag{name, usage, value, value.String()}
|
||||||
_, alreadythere := f.formal[name]
|
_, alreadythere := f.formal[name]
|
||||||
if alreadythere {
|
if alreadythere {
|
||||||
var msg string
|
var msg string
|
||||||
if f.name == "" {
|
if f.name == "" {
|
||||||
msg = fmt.Sprintf("flag redefined: %s", name)
|
msg = f.sprintf("flag redefined: %s", name)
|
||||||
} else {
|
} else {
|
||||||
msg = fmt.Sprintf("%s flag redefined: %s", f.name, name)
|
msg = f.sprintf("%s flag redefined: %s", f.name, name)
|
||||||
}
|
}
|
||||||
fmt.Fprintln(f.Output(), msg)
|
|
||||||
panic(msg) // Happens only if flags are declared with identical names
|
panic(msg) // Happens only if flags are declared with identical names
|
||||||
}
|
}
|
||||||
if f.formal == nil {
|
if f.formal == nil {
|
||||||
|
|
@ -886,13 +892,19 @@ func Var(value Value, name string, usage string) {
|
||||||
CommandLine.Var(value, name, usage)
|
CommandLine.Var(value, name, usage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sprintf formats the message, prints it to output, and returns it.
|
||||||
|
func (f *FlagSet) sprintf(format string, a ...interface{}) string {
|
||||||
|
msg := fmt.Sprintf(format, a...)
|
||||||
|
fmt.Fprintln(f.Output(), msg)
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
// failf prints to standard error a formatted error and usage message and
|
// failf prints to standard error a formatted error and usage message and
|
||||||
// returns the error.
|
// returns the error.
|
||||||
func (f *FlagSet) failf(format string, a ...interface{}) error {
|
func (f *FlagSet) failf(format string, a ...interface{}) error {
|
||||||
err := fmt.Errorf(format, a...)
|
msg := f.sprintf(format, a...)
|
||||||
fmt.Fprintln(f.Output(), err)
|
|
||||||
f.usage()
|
f.usage()
|
||||||
return err
|
return errors.New(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// usage calls the Usage method for the flag set if one is specified,
|
// usage calls the Usage method for the flag set if one is specified,
|
||||||
|
|
|
||||||
|
|
@ -655,3 +655,86 @@ func TestExitCode(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustPanic(t *testing.T, testName string, expected string, f func()) {
|
||||||
|
t.Helper()
|
||||||
|
defer func() {
|
||||||
|
switch msg := recover().(type) {
|
||||||
|
case nil:
|
||||||
|
t.Errorf("%s\n: expected panic(%q), but did not panic", testName, expected)
|
||||||
|
case string:
|
||||||
|
if msg != expected {
|
||||||
|
t.Errorf("%s\n: expected panic(%q), but got panic(%q)", testName, expected, msg)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Errorf("%s\n: expected panic(%q), but got panic(%T%v)", testName, expected, msg, msg)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidFlags(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
flag string
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
flag: "-foo",
|
||||||
|
errorMsg: "flag \"-foo\" begins with -",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
flag: "foo=bar",
|
||||||
|
errorMsg: "flag \"foo=bar\" contains =",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
testName := fmt.Sprintf("FlagSet.Var(&v, %q, \"\")", test.flag)
|
||||||
|
|
||||||
|
fs := NewFlagSet("", ContinueOnError)
|
||||||
|
buf := bytes.NewBuffer(nil)
|
||||||
|
fs.SetOutput(buf)
|
||||||
|
|
||||||
|
mustPanic(t, testName, test.errorMsg, func() {
|
||||||
|
var v flagVar
|
||||||
|
fs.Var(&v, test.flag, "")
|
||||||
|
})
|
||||||
|
if msg := test.errorMsg + "\n"; msg != buf.String() {
|
||||||
|
t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedefinedFlags(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
flagSetName string
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
flagSetName: "",
|
||||||
|
errorMsg: "flag redefined: foo",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
flagSetName: "fs",
|
||||||
|
errorMsg: "fs flag redefined: foo",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
testName := fmt.Sprintf("flag redefined in FlagSet(%q)", test.flagSetName)
|
||||||
|
|
||||||
|
fs := NewFlagSet(test.flagSetName, ContinueOnError)
|
||||||
|
buf := bytes.NewBuffer(nil)
|
||||||
|
fs.SetOutput(buf)
|
||||||
|
|
||||||
|
var v flagVar
|
||||||
|
fs.Var(&v, "foo", "")
|
||||||
|
|
||||||
|
mustPanic(t, testName, test.errorMsg, func() {
|
||||||
|
fs.Var(&v, "foo", "")
|
||||||
|
})
|
||||||
|
if msg := test.errorMsg + "\n"; msg != buf.String() {
|
||||||
|
t.Errorf("%s\n: unexpected output: expected %q, bug got %q", testName, msg, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue