diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 086e5b3a8f5..4919d6ad370 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -12,6 +12,7 @@ import ( "math" "math/bits" "slices" + "strings" ) type branch int @@ -132,7 +133,7 @@ type limit struct { } func (l limit) String() string { - return fmt.Sprintf("sm,SM,um,UM=%d,%d,%d,%d", l.min, l.max, l.umin, l.umax) + return fmt.Sprintf("sm,SM=%d,%d um,UM=%d,%d", l.min, l.max, l.umin, l.umax) } func (l limit) intersect(l2 limit) limit { @@ -1965,6 +1966,30 @@ func (ft *factsTable) flowLimit(v *Value) bool { b := ft.limits[v.Args[1].ID] bitsize := uint(v.Type.Size()) * 8 return ft.newLimit(v, a.mul(b.exp2(bitsize), bitsize)) + case OpRsh64x64, OpRsh64x32, OpRsh64x16, OpRsh64x8, + OpRsh32x64, OpRsh32x32, OpRsh32x16, OpRsh32x8, + OpRsh16x64, OpRsh16x32, OpRsh16x16, OpRsh16x8, + OpRsh8x64, OpRsh8x32, OpRsh8x16, OpRsh8x8: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + if b.min >= 0 { + // Shift of negative makes a value closer to 0 (greater), + // so if a.min is negative, v.min is a.min>>b.min instead of a.min>>b.max, + // and similarly if a.max is negative, v.max is a.max>>b.max. + // Easier to compute min and max of both than to write sign logic. + vmin := min(a.min>>b.min, a.min>>b.max) + vmax := max(a.max>>b.min, a.max>>b.max) + return ft.signedMinMax(v, vmin, vmax) + } + case OpRsh64Ux64, OpRsh64Ux32, OpRsh64Ux16, OpRsh64Ux8, + OpRsh32Ux64, OpRsh32Ux32, OpRsh32Ux16, OpRsh32Ux8, + OpRsh16Ux64, OpRsh16Ux32, OpRsh16Ux16, OpRsh16Ux8, + OpRsh8Ux64, OpRsh8Ux32, OpRsh8Ux16, OpRsh8Ux8: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + if b.min >= 0 { + return ft.unsignedMinMax(v, a.umin>>b.max, a.umax>>b.min) + } case OpDiv64, OpDiv32, OpDiv16, OpDiv8: a := ft.limits[v.Args[0].ID] b := ft.limits[v.Args[1].ID] @@ -2621,6 +2646,17 @@ var bytesizeToAnd = [...]Op{ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { for iv, v := range b.Values { switch v.Op { + case OpStaticLECall: + if b.Func.pass.debug > 0 && len(v.Args) == 2 { + fn := auxToCall(v.Aux).Fn + if fn != nil && strings.Contains(fn.String(), "prove") { + // Print bounds of any argument to single-arg function with "prove" in name, + // for debugging and especially for test/prove.go. + // (v.Args[1] is mem). + x := v.Args[0] + b.Func.Warnl(v.Pos, "Proved %v (%v)", ft.limits[x.ID], x) + } + } case OpSlicemask: // Replace OpSlicemask operations in b with constants where possible. cap := v.Args[0] @@ -2670,21 +2706,8 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { case OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64, OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64, OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64, - OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64: - // Check whether, for a >> b, we know that a is non-negative - // and b is all of a's bits except the MSB. If so, a is shifted to zero. - bits := 8 * v.Args[0].Type.Size() - if v.Args[1].isGenericIntConst() && v.Args[1].AuxInt >= bits-1 && ft.isNonNegative(v.Args[0]) { - if b.Func.pass.debug > 0 { - b.Func.Warnl(v.Pos, "Proved %v shifts to zero", v.Op) - } - v.reset(bytesizeToConst[bits/8]) - v.AuxInt = 0 - break // Be sure not to fallthrough - this is no longer OpRsh. - } - // If the Rsh hasn't been replaced with 0, still check if it is bounded. - fallthrough - case OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64, + OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64, + OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64, OpLsh16x8, OpLsh16x16, OpLsh16x32, OpLsh16x64, OpLsh32x8, OpLsh32x16, OpLsh32x32, OpLsh32x64, OpLsh64x8, OpLsh64x16, OpLsh64x32, OpLsh64x64, diff --git a/test/prove.go b/test/prove.go index db32d1beb0d..365e8ba006e 100644 --- a/test/prove.go +++ b/test/prove.go @@ -971,40 +971,6 @@ func negIndex2(n int) { useSlice(c) } -// Check that prove is zeroing these right shifts of positive ints by bit-width - 1. -// e.g (Rsh64x64 n (Const64 [63])) && ft.isNonNegative(n) -> 0 -func sh64(n int64) int64 { - if n < 0 { - return n - } - return n >> 63 // ERROR "Proved Rsh64x64 shifts to zero" -} - -func sh32(n int32) int32 { - if n < 0 { - return n - } - return n >> 31 // ERROR "Proved Rsh32x64 shifts to zero" -} - -func sh32x64(n int32) int32 { - if n < 0 { - return n - } - return n >> uint64(31) // ERROR "Proved Rsh32x64 shifts to zero" -} - -func sh16(n int16) int16 { - if n < 0 { - return n - } - return n >> 15 // ERROR "Proved Rsh16x64 shifts to zero" -} - -func sh64noopt(n int64) int64 { - return n >> 63 // not optimized; n could be negative -} - // These cases are division of a positive signed integer by a power of 2. // The opt pass doesnt have sufficient information to see that n is positive. // So, instead, opt rewrites the division with a less-than-optimal replacement. @@ -2584,6 +2550,103 @@ func swapbound(v []int) { } } +func rightshift(v *[256]int) int { + for i := range 1024 { // ERROR "Induction" + if v[i/32] == 0 { // ERROR "Proved Div64 is unsigned" "Proved IsInBounds" + return i + } + } + for i := range 1024 { // ERROR "Induction" + if v[i>>2] == 0 { // ERROR "Proved IsInBounds" + return i + } + } + return -1 +} + +func rightShiftBounds(v, s int) { + // The ignored "Proved" messages on the shift itself are about whether s >= 0 or s < 32 or 64. + // We care about the bounds for x printed on the prove(x) lines. + + if -8 <= v && v <= -2 && 1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-4,-1 " + } + if -80 <= v && v <= -20 && 1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-40,-3 " + } + if -8 <= v && v <= 10 && 1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-4,5 " + } + if 2 <= v && v <= 10 && 1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=0,5 " + } + + if -8 <= v && v <= -2 && 0 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-8,-1 " + } + if -80 <= v && v <= -20 && 0 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-80,-3 " + } + if -8 <= v && v <= 10 && 0 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-8,10 " + } + if 2 <= v && v <= 10 && 0 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=0,10 " + } + + if -8 <= v && v <= -2 && -1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-8,-1 " + } + if -80 <= v && v <= -20 && -1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-80,-3 " + } + if -8 <= v && v <= 10 && -1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=-8,10 " + } + if 2 <= v && v <= 10 && -1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + prove(x) // ERROR "Proved sm,SM=0,10 " + } +} + +func unsignedRightShiftBounds(v uint, s int) { + if 2 <= v && v <= 10 && -1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + proveu(x) // ERROR "Proved sm,SM=0,10 " + } + if 2 <= v && v <= 10 && 0 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + proveu(x) // ERROR "Proved sm,SM=0,10 " + } + if 2 <= v && v <= 10 && 1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + proveu(x) // ERROR "Proved sm,SM=0,5 " + } + if 20 <= v && v <= 100 && 1 <= s && s <= 3 { + x := v>>s // ERROR "Proved" + proveu(x) // ERROR "Proved sm,SM=2,50 " + } +} + +//go:noinline +func prove(x int) { +} + +//go:noinline +func proveu(x uint) { +} + //go:noinline func useInt(a int) { } diff --git a/test/prove_constant_folding.go b/test/prove_constant_folding.go index 1029c8e2d3a..46764f9b9d9 100644 --- a/test/prove_constant_folding.go +++ b/test/prove_constant_folding.go @@ -20,14 +20,62 @@ func f0i(x int) int { return x + 1 } -func f0u(x uint) uint { +func f0u(x uint) int { if x == 20 { - return x // ERROR "Proved.+is constant 20$" + return int(x) // ERROR "Proved.+is constant 20$" } if (x + 20) == 20 { - return x + 5 // ERROR "Proved.+is constant 0$" "Proved.+is constant 5$" "x\+d >=? w" + return int(x + 5) // ERROR "Proved.+is constant 0$" "Proved.+is constant 5$" "x\+d >=? w" } - return x + 1 + if x < 1000 { + return int(x)>>31 // ERROR "Proved.+is constant 0$" + } + if x := int32(x); x < -1000 { + return int(x>>31) // ERROR "Proved.+is constant -1$" + } + + return int(x) + 1 +} + +// Check that prove is zeroing these right shifts of positive ints by bit-width - 1. +// e.g (Rsh64x64 n (Const64 [63])) && ft.isNonNegative(n) -> 0 +func sh64(n int64) int64 { + if n < 0 { + return n + } + return n >> 63 // ERROR "Proved .+ is constant 0$" +} + +func sh32(n int32) int32 { + if n < 0 { + return n + } + return n >> 31 // ERROR "Proved .+ is constant 0$" +} + +func sh32x64(n int32) int32 { + if n < 0 { + return n + } + return n >> uint64(31) // ERROR "Proved .+ is constant 0$" +} + +func sh32x64n(n int32) int32 { + if n >= 0 { + return 0 + } + return n >> 31// ERROR "Proved .+ is constant -1$" +} + +func sh16(n int16) int16 { + if n < 0 { + return n + } + return n >> 15 // ERROR "Proved .+ is constant 0$" +} + +func sh64noopt(n int64) int64 { + return n >> 63 // not optimized; n could be negative }