image/jpeg: correct and test reference slowFDCT and slowIDCT

The reference implementations slowFDCT and slowIDCT were not
rounding correctly, making the test not as good as it could be.
Before, the real implementations were required to always produce
values within ±2 of the reference; now, with no changes,
the real implementations produce values within ±1 of the (corrected)
reference.

Also tighten the test to return an error not just on a single value
exceeding tolerance but also on too many values at exactly that
tolerance.

Change-Id: I3dd6ca7582178fef972fb812d848f7a0158a6ed8
Reviewed-on: https://go-review.googlesource.com/c/go/+/705517
Auto-Submit: Russ Cox <rsc@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Alan Donovan <adonovan@google.com>
This commit is contained in:
Russ Cox 2025-09-19 08:36:11 -07:00 committed by Gopher Robot
parent 27c7bbc51c
commit 92e093467f
2 changed files with 220 additions and 52 deletions

View file

@ -7,6 +7,7 @@ package jpeg
import (
"fmt"
"math"
"math/big"
"math/rand"
"strings"
"testing"
@ -29,11 +30,37 @@ func BenchmarkIDCT(b *testing.B) {
benchmarkDCT(b, idct)
}
const testSlowVsBig = true
func TestDCT(t *testing.T) {
blocks := make([]block, len(testBlocks))
copy(blocks, testBlocks[:])
// Append some randomly generated blocks of varying sparseness.
// All zeros
blocks = append(blocks, block{})
// Every possible unit impulse.
for i := range blockSize {
var b block
b[i] = 255
blocks = append(blocks, b)
}
// All ones.
var ones block
for i := range ones {
ones[i] = 255
}
blocks = append(blocks, ones)
// Every possible inverted unit impulse.
for i := range blockSize {
ones[i] = 0
blocks = append(blocks, ones)
ones[i] = 255
}
// Some randomly generated blocks of varying sparseness.
r := rand.New(rand.NewSource(123))
for i := 0; i < 100; i++ {
b := block{}
@ -44,57 +71,84 @@ func TestDCT(t *testing.T) {
blocks = append(blocks, b)
}
// Check that the FDCT and IDCT functions are inverses, after a scale and
// level shift. Scaling reduces the rounding errors in the conversion from
// floats to ints.
for i, b := range blocks {
got, want := b, b
slowFDCT(&got)
slowIDCT(&got)
for j := range got {
got[j] = got[j]/8 + 128
}
if d := differ(&got, &want, 2); d >= 0 {
t.Errorf("i=%d: IDCT(FDCT) (diff at %d,%d)\nsrc\n%s\ngot\n%s\nwant\n%s\n", i, d/8, d%8, &b, &got, &want)
// Check that the slow FDCT and IDCT functions are inverses,
// after a scale and level shift.
// Scaling reduces the rounding errors in the conversion.
// The “fast” ones are not inverses because the fast IDCT
// is optimized for 8-bit inputs, not full 16-bit ones.
slowRoundTrip := func(b *block) {
slowFDCT(b)
slowIDCT(b)
for j := range b {
b[j] = b[j]/8 + 128
}
}
nop := func(*block) {}
testDCT(t, "IDCT(FDCT)", blocks, slowRoundTrip, nop, 1, 8)
if testSlowVsBig {
testDCT(t, "slowFDCT", blocks, slowFDCT, slowerFDCT, 0, 64)
testDCT(t, "slowIDCT", blocks, slowIDCT, slowerIDCT, 0, 64)
}
// Check that the optimized and slow FDCT implementations agree.
// The fdct function already does a scale and level shift.
for i, b := range blocks {
got, want := b, b
fdct(&got)
slowFDCT(&want)
if d := differ(&got, &want, 2); d >= 0 {
t.Errorf("i=%d: FDCT (diff at %d,%d)\nsrc\n%s\ngot\n%s\nwant\n%s\n", i, d/8, d%8, &b, &got, &want)
}
}
// Check that the optimized and slow IDCT implementations agree.
for i, b := range blocks {
got, want := b, b
idct(&got)
slowIDCT(&want)
if d := differ(&got, &want, 2); d >= 0 {
t.Errorf("i=%d: IDCT (diff at %d,%d)\nsrc\n%s\ngot\n%s\nwant\n%s\n", i, d/8, d%8, &b, &got, &want)
}
}
testDCT(t, "FDCT", blocks, fdct, slowFDCT, 1, 16)
testDCT(t, "IDCT", blocks, idct, slowIDCT, 1, 8)
}
// differ reports whether any pair-wise elements in b0 and b1 differ by more than 'ok'.
// That tolerance is because there isn't a single definitive decoding of
// a given JPEG image, even before the YCbCr to RGB conversion; implementations
func testDCT(t *testing.T, name string, blocks []block, fhave, fwant func(*block), tolerance int32, maxCloseCalls int) {
t.Run(name, func(t *testing.T) {
totalClose := 0
for i, b := range blocks {
have, want := b, b
fhave(&have)
fwant(&want)
d, n := differ(&have, &want, tolerance)
if d >= 0 || n > maxCloseCalls {
fail := ""
if d >= 0 {
fail = fmt.Sprintf("diff at %d,%d", d/8, d%8)
}
if n > maxCloseCalls {
if fail != "" {
fail += "; "
}
fail += fmt.Sprintf("%d close calls", n)
}
t.Errorf("i=%d: %s (%s)\nsrc\n%s\nhave\n%s\nwant\n%s\n",
i, name, fail, &b, &have, &want)
}
totalClose += n
}
if tolerance > 0 {
t.Logf("%d/%d total close calls", totalClose, len(blocks)*blockSize)
}
})
}
// differ returns the index of the first pair-wise elements in b0 and b1
// that differ by more than 'ok', along with the total number of elements
// that differ by at least ok ("close calls").
//
// There isn't a single definitive decoding of a given JPEG image,
// even before the YCbCr to RGB conversion; implementations
// can have different IDCT rounding errors.
// If there is a difference, differ returns the index of the first difference.
// Otherwise it returns -1.
func differ(b0, b1 *block, ok int32) int {
//
// If there are no differences, differ returns -1, 0.
func differ(b0, b1 *block, ok int32) (index, closeCalls int) {
index = -1
for i := range b0 {
delta := b0[i] - b1[i]
if delta < -ok || ok < delta {
return i
if index < 0 {
index = i
}
}
if delta <= -ok || ok <= delta {
closeCalls++
}
}
return -1
return
}
// alpha returns 1 if i is 0 and returns √2 otherwise.
@ -105,6 +159,14 @@ func alpha(i int) float64 {
return math.Sqrt2
}
// bigAlpha returns 1 if i is 0 and returns √2 otherwise.
func bigAlpha(i int) *big.Float {
if i == 0 {
return bigFloat1
}
return bigFloatSqrt2
}
var cosines = [32]float64{
+1.0000000000000000000000000000000000000000000000000000000000000000, // cos(π/16 * 0)
+0.9807852804032304491261822361342390369739337308933360950029160885, // cos(π/16 * 1)
@ -143,6 +205,57 @@ var cosines = [32]float64{
+0.9807852804032304491261822361342390369739337308933360950029160885, // cos(π/16 * 31)
}
func bigFloat(s string) *big.Float {
f, ok := new(big.Float).SetString(s)
if !ok {
panic("bad float")
}
return f
}
var (
bigFloat1 = big.NewFloat(1)
bigFloatSqrt2 = bigFloat("1.41421356237309504880168872420969807856967187537694807317667974")
)
var bigCosines = [32]*big.Float{
bigFloat("+1.0000000000000000000000000000000000000000000000000000000000000000"), // cos(π/16 * 0)
bigFloat("+0.9807852804032304491261822361342390369739337308933360950029160885"), // cos(π/16 * 1)
bigFloat("+0.9238795325112867561281831893967882868224166258636424861150977312"), // cos(π/16 * 2)
bigFloat("+0.8314696123025452370787883776179057567385608119872499634461245902"), // cos(π/16 * 3)
bigFloat("+0.7071067811865475244008443621048490392848359376884740365883398689"), // cos(π/16 * 4)
bigFloat("+0.5555702330196022247428308139485328743749371907548040459241535282"), // cos(π/16 * 5)
bigFloat("+0.3826834323650897717284599840303988667613445624856270414338006356"), // cos(π/16 * 6)
bigFloat("+0.1950903220161282678482848684770222409276916177519548077545020894"), // cos(π/16 * 7)
bigFloat("-0.0000000000000000000000000000000000000000000000000000000000000000"), // cos(π/16 * 8)
bigFloat("-0.1950903220161282678482848684770222409276916177519548077545020894"), // cos(π/16 * 9)
bigFloat("-0.3826834323650897717284599840303988667613445624856270414338006356"), // cos(π/16 * 10)
bigFloat("-0.5555702330196022247428308139485328743749371907548040459241535282"), // cos(π/16 * 11)
bigFloat("-0.7071067811865475244008443621048490392848359376884740365883398689"), // cos(π/16 * 12)
bigFloat("-0.8314696123025452370787883776179057567385608119872499634461245902"), // cos(π/16 * 13)
bigFloat("-0.9238795325112867561281831893967882868224166258636424861150977312"), // cos(π/16 * 14)
bigFloat("-0.9807852804032304491261822361342390369739337308933360950029160885"), // cos(π/16 * 15)
bigFloat("-1.0000000000000000000000000000000000000000000000000000000000000000"), // cos(π/16 * 16)
bigFloat("-0.9807852804032304491261822361342390369739337308933360950029160885"), // cos(π/16 * 17)
bigFloat("-0.9238795325112867561281831893967882868224166258636424861150977312"), // cos(π/16 * 18)
bigFloat("-0.8314696123025452370787883776179057567385608119872499634461245902"), // cos(π/16 * 19)
bigFloat("-0.7071067811865475244008443621048490392848359376884740365883398689"), // cos(π/16 * 20)
bigFloat("-0.5555702330196022247428308139485328743749371907548040459241535282"), // cos(π/16 * 21)
bigFloat("-0.3826834323650897717284599840303988667613445624856270414338006356"), // cos(π/16 * 22)
bigFloat("-0.1950903220161282678482848684770222409276916177519548077545020894"), // cos(π/16 * 23)
bigFloat("+0.0000000000000000000000000000000000000000000000000000000000000000"), // cos(π/16 * 24)
bigFloat("+0.1950903220161282678482848684770222409276916177519548077545020894"), // cos(π/16 * 25)
bigFloat("+0.3826834323650897717284599840303988667613445624856270414338006356"), // cos(π/16 * 26)
bigFloat("+0.5555702330196022247428308139485328743749371907548040459241535282"), // cos(π/16 * 27)
bigFloat("+0.7071067811865475244008443621048490392848359376884740365883398689"), // cos(π/16 * 28)
bigFloat("+0.8314696123025452370787883776179057567385608119872499634461245902"), // cos(π/16 * 29)
bigFloat("+0.9238795325112867561281831893967882868224166258636424861150977312"), // cos(π/16 * 30)
bigFloat("+0.9807852804032304491261822361342390369739337308933360950029160885"), // cos(π/16 * 31)
}
// slowFDCT performs the 8*8 2-dimensional forward discrete cosine transform:
//
// dst[u,v] = (1/8) * Σ_x Σ_y alpha(u) * alpha(v) * src[x,y] *
@ -153,7 +266,7 @@ var cosines = [32]float64{
//
// b acts as both dst and src.
func slowFDCT(b *block) {
var dst [blockSize]float64
var dst block
for v := 0; v < 8; v++ {
for u := 0; u < 8; u++ {
sum := 0.0
@ -164,13 +277,40 @@ func slowFDCT(b *block) {
cosines[((2*y+1)*v)%32]
}
}
dst[8*v+u] = sum
dst[8*v+u] = int32(math.Round(sum))
}
}
// Convert from float64 to int32.
for i := range dst {
b[i] = int32(dst[i] + 0.5)
*b = dst
}
// slowerFDCT is slowFDCT but using big.Floats to validate slowFDCT.
func slowerFDCT(b *block) {
var dst block
for v := 0; v < 8; v++ {
for u := 0; u < 8; u++ {
sum := big.NewFloat(0)
for y := 0; y < 8; y++ {
for x := 0; x < 8; x++ {
f := big.NewFloat(float64(b[8*y+x] - 128))
f = new(big.Float).Mul(f, bigAlpha(u))
f = new(big.Float).Mul(f, bigAlpha(v))
f = new(big.Float).Mul(f, bigCosines[((2*x+1)*u)%32])
f = new(big.Float).Mul(f, bigCosines[((2*y+1)*v)%32])
sum = new(big.Float).Add(sum, f)
}
}
// Int64 truncates toward zero, so add ±0.5
// as needed to round
if sum.Sign() > 0 {
sum = new(big.Float).Add(sum, big.NewFloat(+0.5))
} else {
sum = new(big.Float).Add(sum, big.NewFloat(-0.5))
}
i, _ := sum.Int64()
dst[8*v+u] = int32(i)
}
}
*b = dst
}
// slowIDCT performs the 8*8 2-dimensional inverse discrete cosine transform:
@ -183,7 +323,7 @@ func slowFDCT(b *block) {
//
// b acts as both dst and src.
func slowIDCT(b *block) {
var dst [blockSize]float64
var dst block
for y := 0; y < 8; y++ {
for x := 0; x < 8; x++ {
sum := 0.0
@ -194,13 +334,41 @@ func slowIDCT(b *block) {
cosines[((2*y+1)*v)%32]
}
}
dst[8*y+x] = sum / 8
dst[8*y+x] = int32(math.Round(sum / 8))
}
}
// Convert from float64 to int32.
for i := range dst {
b[i] = int32(dst[i] + 0.5)
*b = dst
}
// slowerIDCT is slowIDCT but using big.Floats to validate slowIDCT.
func slowerIDCT(b *block) {
var dst block
for y := 0; y < 8; y++ {
for x := 0; x < 8; x++ {
sum := big.NewFloat(0)
for v := 0; v < 8; v++ {
for u := 0; u < 8; u++ {
f := big.NewFloat(float64(b[8*v+u]))
f = new(big.Float).Mul(f, bigAlpha(u))
f = new(big.Float).Mul(f, bigAlpha(v))
f = new(big.Float).Mul(f, bigCosines[((2*x+1)*u)%32])
f = new(big.Float).Mul(f, bigCosines[((2*y+1)*v)%32])
f = new(big.Float).Quo(f, big.NewFloat(8))
sum = new(big.Float).Add(sum, f)
}
}
// Int64 truncates toward zero, so add ±0.5
// as needed to round
if sum.Sign() > 0 {
sum = new(big.Float).Add(sum, big.NewFloat(+0.5))
} else {
sum = new(big.Float).Add(sum, big.NewFloat(-0.5))
}
i, _ := sum.Int64()
dst[8*y+x] = int32(i)
}
}
*b = dst
}
func (b *block) String() string {

View file

@ -154,8 +154,8 @@ func TestWriter(t *testing.T) {
continue
}
// Compare the average delta to the tolerance level.
if averageDelta(m0, m1) > tc.tolerance {
t.Errorf("%s, quality=%d: average delta is too high", tc.filename, tc.quality)
if d := averageDelta(m0, m1); d > tc.tolerance {
t.Errorf("%s, quality=%d: average delta is too high (%d > %d)", tc.filename, tc.quality, d, tc.tolerance)
continue
}
}