diff --git a/src/hash/maphash/maphash.go b/src/hash/maphash/maphash.go index c45964f89e4..5cc0c581c7c 100644 --- a/src/hash/maphash/maphash.go +++ b/src/hash/maphash/maphash.go @@ -13,7 +13,10 @@ // package maphash -import "unsafe" +import ( + "internal/unsafeheader" + "unsafe" +) // A Seed is a random value that selects the specific hash function // computed by a Hash. If two Hashes use the same Seeds, they @@ -54,13 +57,19 @@ type Seed struct { // If multiple goroutines must compute the same seeded hash, // each can declare its own Hash and call SetSeed with a common Seed. type Hash struct { - _ [0]func() // not comparable - seed Seed // initial seed used for this hash - state Seed // current hash of all flushed bytes - buf [64]byte // unflushed byte buffer - n int // number of unflushed bytes + _ [0]func() // not comparable + seed Seed // initial seed used for this hash + state Seed // current hash of all flushed bytes + buf [bufSize]byte // unflushed byte buffer + n int // number of unflushed bytes } +// bufSize is the size of the Hash write buffer. +// The buffer ensures that writes depend only on the sequence of bytes, +// not the sequence of WriteByte/Write/WriteString calls, +// by always calling rthash with a full buffer (except for the tail). +const bufSize = 64 + // initSeed seeds the hash if necessary. // initSeed is called lazily before any operation that actually uses h.seed/h.state. // Note that this does not include Write/WriteByte/WriteString in the case @@ -89,27 +98,58 @@ func (h *Hash) WriteByte(b byte) error { // It always writes all of b and never fails; the count and error result are for implementing io.Writer. func (h *Hash) Write(b []byte) (int, error) { size := len(b) - for h.n+len(b) > len(h.buf) { + // Deal with bytes left over in h.buf. + // h.n <= bufSize is always true. + // Checking it is ~free and it lets the compiler eliminate a bounds check. + if h.n > 0 && h.n <= bufSize { k := copy(h.buf[h.n:], b) - h.n = len(h.buf) + h.n += k + if h.n < bufSize { + // Copied the entirety of b to h.buf. + return size, nil + } b = b[k:] h.flush() + // No need to set h.n = 0 here; it happens just before exit. } - h.n += copy(h.buf[h.n:], b) + // Process as many full buffers as possible, without copying, and calling initSeed only once. + if len(b) > bufSize { + h.initSeed() + for len(b) > bufSize { + h.state.s = rthash(&b[0], bufSize, h.state.s) + b = b[bufSize:] + } + } + // Copy the tail. + copy(h.buf[:], b) + h.n = len(b) return size, nil } // WriteString adds the bytes of s to the sequence of bytes hashed by h. // It always writes all of s and never fails; the count and error result are for implementing io.StringWriter. func (h *Hash) WriteString(s string) (int, error) { + // WriteString mirrors Write. See Write for comments. size := len(s) - for h.n+len(s) > len(h.buf) { + if h.n > 0 && h.n <= bufSize { k := copy(h.buf[h.n:], s) - h.n = len(h.buf) + h.n += k + if h.n < bufSize { + return size, nil + } s = s[k:] h.flush() } - h.n += copy(h.buf[h.n:], s) + if len(s) > bufSize { + h.initSeed() + for len(s) > bufSize { + ptr := (*byte)((*unsafeheader.String)(unsafe.Pointer(&s)).Data) + h.state.s = rthash(ptr, bufSize, h.state.s) + s = s[bufSize:] + } + } + copy(h.buf[:], s) + h.n = len(s) return size, nil } @@ -147,7 +187,7 @@ func (h *Hash) flush() { panic("maphash: flush of partially full buffer") } h.initSeed() - h.state.s = rthash(h.buf[:], h.state.s) + h.state.s = rthash(&h.buf[0], h.n, h.state.s) h.n = 0 } @@ -160,7 +200,7 @@ func (h *Hash) flush() { // by using bit masking, shifting, or modular arithmetic. func (h *Hash) Sum64() uint64 { h.initSeed() - return rthash(h.buf[:h.n], h.state.s) + return rthash(&h.buf[0], h.n, h.state.s) } // MakeSeed returns a new random seed. @@ -181,18 +221,18 @@ func MakeSeed() Seed { //go:linkname runtime_fastrand runtime.fastrand func runtime_fastrand() uint32 -func rthash(b []byte, seed uint64) uint64 { - if len(b) == 0 { +func rthash(ptr *byte, len int, seed uint64) uint64 { + if len == 0 { return seed } // The runtime hasher only works on uintptr. For 64-bit // architectures, we use the hasher directly. Otherwise, // we use two parallel hashers on the lower and upper 32 bits. if unsafe.Sizeof(uintptr(0)) == 8 { - return uint64(runtime_memhash(unsafe.Pointer(&b[0]), uintptr(seed), uintptr(len(b)))) + return uint64(runtime_memhash(unsafe.Pointer(ptr), uintptr(seed), uintptr(len))) } - lo := runtime_memhash(unsafe.Pointer(&b[0]), uintptr(seed), uintptr(len(b))) - hi := runtime_memhash(unsafe.Pointer(&b[0]), uintptr(seed>>32), uintptr(len(b))) + lo := runtime_memhash(unsafe.Pointer(ptr), uintptr(seed), uintptr(len)) + hi := runtime_memhash(unsafe.Pointer(ptr), uintptr(seed>>32), uintptr(len)) return uint64(hi)<<32 | uint64(lo) } diff --git a/src/hash/maphash/maphash_test.go b/src/hash/maphash/maphash_test.go index daf6eb47866..78cdfc0e737 100644 --- a/src/hash/maphash/maphash_test.go +++ b/src/hash/maphash/maphash_test.go @@ -5,6 +5,7 @@ package maphash import ( + "bytes" "hash" "testing" ) @@ -34,19 +35,57 @@ func TestSeededHash(t *testing.T) { } func TestHashGrouping(t *testing.T) { - b := []byte("foo") - h1 := new(Hash) - h2 := new(Hash) - h2.SetSeed(h1.Seed()) - h1.Write(b) - for _, x := range b { - err := h2.WriteByte(x) + b := bytes.Repeat([]byte("foo"), 100) + hh := make([]*Hash, 7) + for i := range hh { + hh[i] = new(Hash) + } + for _, h := range hh[1:] { + h.SetSeed(hh[0].Seed()) + } + hh[0].Write(b) + hh[1].WriteString(string(b)) + + writeByte := func(h *Hash, b byte) { + err := h.WriteByte(b) if err != nil { t.Fatalf("WriteByte: %v", err) } } - if h1.Sum64() != h2.Sum64() { - t.Errorf("hash of \"foo\" and \"f\",\"o\",\"o\" not identical") + writeSingleByte := func(h *Hash, b byte) { + _, err := h.Write([]byte{b}) + if err != nil { + t.Fatalf("Write single byte: %v", err) + } + } + writeStringSingleByte := func(h *Hash, b byte) { + _, err := h.WriteString(string([]byte{b})) + if err != nil { + t.Fatalf("WriteString single byte: %v", err) + } + } + + for i, x := range b { + writeByte(hh[2], x) + writeSingleByte(hh[3], x) + if i == 0 { + writeByte(hh[4], x) + } else { + writeSingleByte(hh[4], x) + } + writeStringSingleByte(hh[5], x) + if i == 0 { + writeByte(hh[6], x) + } else { + writeStringSingleByte(hh[6], x) + } + } + + sum := hh[0].Sum64() + for i, h := range hh { + if sum != h.Sum64() { + t.Errorf("hash %d not identical to a single Write", i) + } } }