mirror of
https://github.com/golang/go.git
synced 2025-12-08 06:10:04 +00:00
math/big: wrap Float.Cmp result in struct to prevent wrong use
Float.Cmp used to return a value < 0, 0, or > 0 depending on how arguments x, y compared against each other. With the possibility of NaNs, the result was changed into an Accuracy (to include Undef). Consequently, Float.Cmp results could still be compared for (in-) equality with 0, but comparing if < 0 or > 0 would provide the wrong answer w/o any obvious notice by the compiler. This change wraps Float.Cmp results into a struct and accessors are used to access the desired result. This prevents incorrect use. Change-Id: I34e6a6c1859251ec99b5cf953e82542025ace56f Reviewed-on: https://go-review.googlesource.com/7526 Reviewed-by: Rob Pike <r@golang.org>
This commit is contained in:
parent
b100216441
commit
23fd374bf2
4 changed files with 30 additions and 25 deletions
|
|
@ -1446,6 +1446,10 @@ func (z *Float) Quo(x, y *Float) *Float {
|
||||||
return z
|
return z
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type cmpResult struct {
|
||||||
|
acc Accuracy
|
||||||
|
}
|
||||||
|
|
||||||
// Cmp compares x and y and returns:
|
// Cmp compares x and y and returns:
|
||||||
//
|
//
|
||||||
// Below if x < y
|
// Below if x < y
|
||||||
|
|
@ -1453,44 +1457,45 @@ func (z *Float) Quo(x, y *Float) *Float {
|
||||||
// Above if x > y
|
// Above if x > y
|
||||||
// Undef if any of x, y is NaN
|
// Undef if any of x, y is NaN
|
||||||
//
|
//
|
||||||
func (x *Float) Cmp(y *Float) Accuracy {
|
func (x *Float) Cmp(y *Float) cmpResult {
|
||||||
if debugFloat {
|
if debugFloat {
|
||||||
x.validate()
|
x.validate()
|
||||||
y.validate()
|
y.validate()
|
||||||
}
|
}
|
||||||
|
|
||||||
if x.form == nan || y.form == nan {
|
if x.form == nan || y.form == nan {
|
||||||
return Undef
|
return cmpResult{Undef}
|
||||||
}
|
}
|
||||||
|
|
||||||
mx := x.ord()
|
mx := x.ord()
|
||||||
my := y.ord()
|
my := y.ord()
|
||||||
switch {
|
switch {
|
||||||
case mx < my:
|
case mx < my:
|
||||||
return Below
|
return cmpResult{Below}
|
||||||
case mx > my:
|
case mx > my:
|
||||||
return Above
|
return cmpResult{Above}
|
||||||
}
|
}
|
||||||
// mx == my
|
// mx == my
|
||||||
|
|
||||||
// only if |mx| == 1 we have to compare the mantissae
|
// only if |mx| == 1 we have to compare the mantissae
|
||||||
switch mx {
|
switch mx {
|
||||||
case -1:
|
case -1:
|
||||||
return y.ucmp(x)
|
return cmpResult{y.ucmp(x)}
|
||||||
case +1:
|
case +1:
|
||||||
return x.ucmp(y)
|
return cmpResult{x.ucmp(y)}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Exact
|
return cmpResult{Exact}
|
||||||
}
|
}
|
||||||
|
|
||||||
// The following accessors simplify testing of Cmp results.
|
// The following accessors simplify testing of Cmp results.
|
||||||
func (acc Accuracy) Eql() bool { return acc == Exact }
|
func (res cmpResult) Acc() Accuracy { return res.acc }
|
||||||
func (acc Accuracy) Neq() bool { return acc != Exact }
|
func (res cmpResult) Eql() bool { return res.acc == Exact }
|
||||||
func (acc Accuracy) Lss() bool { return acc == Below }
|
func (res cmpResult) Neq() bool { return res.acc != Exact }
|
||||||
func (acc Accuracy) Leq() bool { return acc&Above == 0 }
|
func (res cmpResult) Lss() bool { return res.acc == Below }
|
||||||
func (acc Accuracy) Gtr() bool { return acc == Above }
|
func (res cmpResult) Leq() bool { return res.acc&Above == 0 }
|
||||||
func (acc Accuracy) Geq() bool { return acc&Below == 0 }
|
func (res cmpResult) Gtr() bool { return res.acc == Above }
|
||||||
|
func (res cmpResult) Geq() bool { return res.acc&Below == 0 }
|
||||||
|
|
||||||
// ord classifies x and returns:
|
// ord classifies x and returns:
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -207,7 +207,7 @@ func feq(x, y *Float) bool {
|
||||||
if x.IsNaN() || y.IsNaN() {
|
if x.IsNaN() || y.IsNaN() {
|
||||||
return x.IsNaN() && y.IsNaN()
|
return x.IsNaN() && y.IsNaN()
|
||||||
}
|
}
|
||||||
return x.Cmp(y) == 0 && x.IsNeg() == y.IsNeg()
|
return x.Cmp(y).Eql() && x.IsNeg() == y.IsNeg()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFloatMantExp(t *testing.T) {
|
func TestFloatMantExp(t *testing.T) {
|
||||||
|
|
@ -918,7 +918,7 @@ func TestFloatRat(t *testing.T) {
|
||||||
// inverse conversion
|
// inverse conversion
|
||||||
if res != nil {
|
if res != nil {
|
||||||
got := new(Float).SetPrec(64).SetRat(res)
|
got := new(Float).SetPrec(64).SetRat(res)
|
||||||
if got.Cmp(x) != 0 {
|
if got.Cmp(x).Neq() {
|
||||||
t.Errorf("%s: got %s; want %s", test.x, got, x)
|
t.Errorf("%s: got %s; want %s", test.x, got, x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -995,7 +995,7 @@ func TestFloatInc(t *testing.T) {
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
x.Add(&x, &one)
|
x.Add(&x, &one)
|
||||||
}
|
}
|
||||||
if x.Cmp(new(Float).SetInt64(n)) != 0 {
|
if x.Cmp(new(Float).SetInt64(n)).Neq() {
|
||||||
t.Errorf("prec = %d: got %s; want %d", prec, &x, n)
|
t.Errorf("prec = %d: got %s; want %d", prec, &x, n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1036,14 +1036,14 @@ func TestFloatAdd(t *testing.T) {
|
||||||
got := new(Float).SetPrec(prec).SetMode(mode)
|
got := new(Float).SetPrec(prec).SetMode(mode)
|
||||||
got.Add(x, y)
|
got.Add(x, y)
|
||||||
want := zbits.round(prec, mode)
|
want := zbits.round(prec, mode)
|
||||||
if got.Cmp(want) != 0 {
|
if got.Cmp(want).Neq() {
|
||||||
t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t+ %s %v\n\t= %s\n\twant %s",
|
t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t+ %s %v\n\t= %s\n\twant %s",
|
||||||
i, prec, mode, x, xbits, y, ybits, got, want)
|
i, prec, mode, x, xbits, y, ybits, got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
got.Sub(z, x)
|
got.Sub(z, x)
|
||||||
want = ybits.round(prec, mode)
|
want = ybits.round(prec, mode)
|
||||||
if got.Cmp(want) != 0 {
|
if got.Cmp(want).Neq() {
|
||||||
t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t- %s %v\n\t= %s\n\twant %s",
|
t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t- %s %v\n\t= %s\n\twant %s",
|
||||||
i, prec, mode, z, zbits, x, xbits, got, want)
|
i, prec, mode, z, zbits, x, xbits, got, want)
|
||||||
}
|
}
|
||||||
|
|
@ -1137,7 +1137,7 @@ func TestFloatMul(t *testing.T) {
|
||||||
got := new(Float).SetPrec(prec).SetMode(mode)
|
got := new(Float).SetPrec(prec).SetMode(mode)
|
||||||
got.Mul(x, y)
|
got.Mul(x, y)
|
||||||
want := zbits.round(prec, mode)
|
want := zbits.round(prec, mode)
|
||||||
if got.Cmp(want) != 0 {
|
if got.Cmp(want).Neq() {
|
||||||
t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t* %s %v\n\t= %s\n\twant %s",
|
t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t* %s %v\n\t= %s\n\twant %s",
|
||||||
i, prec, mode, x, xbits, y, ybits, got, want)
|
i, prec, mode, x, xbits, y, ybits, got, want)
|
||||||
}
|
}
|
||||||
|
|
@ -1147,7 +1147,7 @@ func TestFloatMul(t *testing.T) {
|
||||||
}
|
}
|
||||||
got.Quo(z, x)
|
got.Quo(z, x)
|
||||||
want = ybits.round(prec, mode)
|
want = ybits.round(prec, mode)
|
||||||
if got.Cmp(want) != 0 {
|
if got.Cmp(want).Neq() {
|
||||||
t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t/ %s %v\n\t= %s\n\twant %s",
|
t.Errorf("i = %d, prec = %d, %s:\n\t %s %v\n\t/ %s %v\n\t= %s\n\twant %s",
|
||||||
i, prec, mode, z, zbits, x, xbits, got, want)
|
i, prec, mode, z, zbits, x, xbits, got, want)
|
||||||
}
|
}
|
||||||
|
|
@ -1230,7 +1230,7 @@ func TestIssue6866(t *testing.T) {
|
||||||
p.Mul(p, psix)
|
p.Mul(p, psix)
|
||||||
z2.Sub(two, p)
|
z2.Sub(two, p)
|
||||||
|
|
||||||
if z1.Cmp(z2) != 0 {
|
if z1.Cmp(z2).Neq() {
|
||||||
t.Fatalf("prec %d: got z1 = %s != z2 = %s; want z1 == z2\n", prec, z1, z2)
|
t.Fatalf("prec %d: got z1 = %s != z2 = %s; want z1 == z2\n", prec, z1, z2)
|
||||||
}
|
}
|
||||||
if z1.Sign() != 0 {
|
if z1.Sign() != 0 {
|
||||||
|
|
@ -1281,7 +1281,7 @@ func TestFloatQuo(t *testing.T) {
|
||||||
prec := uint(preci + d)
|
prec := uint(preci + d)
|
||||||
got := new(Float).SetPrec(prec).SetMode(mode).Quo(x, y)
|
got := new(Float).SetPrec(prec).SetMode(mode).Quo(x, y)
|
||||||
want := bits.round(prec, mode)
|
want := bits.round(prec, mode)
|
||||||
if got.Cmp(want) != 0 {
|
if got.Cmp(want).Neq() {
|
||||||
t.Errorf("i = %d, prec = %d, %s:\n\t %s\n\t/ %s\n\t= %s\n\twant %s",
|
t.Errorf("i = %d, prec = %d, %s:\n\t %s\n\t/ %s\n\t= %s\n\twant %s",
|
||||||
i, prec, mode, x, y, got, want)
|
i, prec, mode, x, y, got, want)
|
||||||
}
|
}
|
||||||
|
|
@ -1462,7 +1462,7 @@ func TestFloatCmpSpecialValues(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, y := range args {
|
for _, y := range args {
|
||||||
yy.SetFloat64(y)
|
yy.SetFloat64(y)
|
||||||
got := xx.Cmp(yy)
|
got := xx.Cmp(yy).Acc()
|
||||||
want := Undef
|
want := Undef
|
||||||
switch {
|
switch {
|
||||||
case x < y:
|
case x < y:
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ func TestFloatSetFloat64String(t *testing.T) {
|
||||||
}
|
}
|
||||||
f, _ := x.Float64()
|
f, _ := x.Float64()
|
||||||
want := new(Float).SetFloat64(test.x)
|
want := new(Float).SetFloat64(test.x)
|
||||||
if x.Cmp(want) != 0 {
|
if x.Cmp(want).Neq() {
|
||||||
t.Errorf("%s: got %s (%v); want %v", test.s, &x, f, test.x)
|
t.Errorf("%s: got %s (%v); want %v", test.s, &x, f, test.x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -63,7 +63,7 @@ func ExampleFloat_Cmp() {
|
||||||
t := x.Cmp(y)
|
t := x.Cmp(y)
|
||||||
fmt.Printf(
|
fmt.Printf(
|
||||||
"%4s %4s %5s %c %c %c %c %c %c\n",
|
"%4s %4s %5s %c %c %c %c %c %c\n",
|
||||||
x, y, t,
|
x, y, t.Acc(),
|
||||||
mark(t.Eql()), mark(t.Neq()), mark(t.Lss()), mark(t.Leq()), mark(t.Gtr()), mark(t.Geq()))
|
mark(t.Eql()), mark(t.Neq()), mark(t.Lss()), mark(t.Leq()), mark(t.Gtr()), mark(t.Geq()))
|
||||||
}
|
}
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue