Revert "internal/sync: optimize CompareAndSwap and Swap"

This reverts CL 606462.

Reason for revert: Breaks atomicity between operations. See #70970.

Change-Id: I1a899f2784da5a0f9da3193e3267275c23aea661
Reviewed-on: https://go-review.googlesource.com/c/go/+/638615
Auto-Submit: Michael Knyszek <mknyszek@google.com>
Commit-Queue: Michael Knyszek <mknyszek@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: David Chase <drchase@google.com>
This commit is contained in:
Michael Knyszek 2024-12-23 17:59:28 -08:00 committed by Gopher Robot
parent 705b5a569a
commit c112c0af13

View file

@ -219,22 +219,12 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
slot = &i.children[(hash>>hashShift)&nChildrenMask] slot = &i.children[(hash>>hashShift)&nChildrenMask]
n = slot.Load() n = slot.Load()
if n == nil { if n == nil || n.isEntry {
// We found a nil slot which is a candidate for insertion, // We found a nil slot which is a candidate for insertion,
// or an existing entry that we'll replace. // or an existing entry that we'll replace.
haveInsertPoint = true haveInsertPoint = true
break break
} }
if n.isEntry {
// Swap if the keys compare.
old, swapped := n.entry().swap(key, new)
if swapped {
return old, true
}
// If we fail, that means we should try to insert.
haveInsertPoint = true
break
}
i = n.indirect() i = n.indirect()
} }
if !haveInsertPoint { if !haveInsertPoint {
@ -261,10 +251,11 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) {
var zero V var zero V
var oldEntry *entry[K, V] var oldEntry *entry[K, V]
if n != nil { if n != nil {
// Between before and now, something got inserted. Swap if the keys compare. // Swap if the keys compare.
oldEntry = n.entry() oldEntry = n.entry()
old, swapped := oldEntry.swap(key, new) newEntry, old, swapped := oldEntry.swap(key, new)
if swapped { if swapped {
slot.Store(&newEntry.node)
return old, true return old, true
} }
} }
@ -292,30 +283,25 @@ func (ht *HashTrieMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
panic("called CompareAndSwap when value is not of comparable type") panic("called CompareAndSwap when value is not of comparable type")
} }
hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed) hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed)
for {
// Find the key or return if it's not there.
i := ht.root.Load()
hashShift := 8 * goarch.PtrSize
found := false
for hashShift != 0 {
hashShift -= nChildrenLog2
slot := &i.children[(hash>>hashShift)&nChildrenMask] // Find a node with the key and compare with it. n != nil if we found the node.
n := slot.Load() i, _, slot, n := ht.find(key, hash, ht.valEqual, old)
if n == nil { if i != nil {
// Nothing to compare with. Give up. defer i.mu.Unlock()
return false
}
if n.isEntry {
// We found an entry. Try to compare and swap directly.
return n.entry().compareAndSwap(key, old, new, ht.valEqual)
}
i = n.indirect()
}
if !found {
panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating")
}
} }
if n == nil {
return false
}
// Try to swap the entry.
e, swapped := n.entry().compareAndSwap(key, old, new, ht.valEqual)
if !swapped {
// Nothing was actually swapped, which means the node is no longer there.
return false
}
// Store the entry back because it changed.
slot.Store(&e.node)
return true
} }
// LoadAndDelete deletes the value for a key, returning the previous value if any. // LoadAndDelete deletes the value for a key, returning the previous value if any.
@ -523,7 +509,7 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V)
} }
e := n.entry() e := n.entry()
for e != nil { for e != nil {
if !yield(e.key, *e.value.Load()) { if !yield(e.key, e.value) {
return false return false
} }
e = e.overflow.Load() e = e.overflow.Load()
@ -579,22 +565,21 @@ type entry[K comparable, V any] struct {
node[K, V] node[K, V]
overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions. overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions.
key K key K
value atomic.Pointer[V] value V
} }
func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] { func newEntryNode[K comparable, V any](key K, value V) *entry[K, V] {
e := &entry[K, V]{ return &entry[K, V]{
node: node[K, V]{isEntry: true}, node: node[K, V]{isEntry: true},
key: key, key: key,
value: value,
} }
e.value.Store(&value)
return e
} }
func (e *entry[K, V]) lookup(key K) (V, bool) { func (e *entry[K, V]) lookup(key K) (V, bool) {
for e != nil { for e != nil {
if e.key == key { if e.key == key {
return *e.value.Load(), true return e.value, true
} }
e = e.overflow.Load() e = e.overflow.Load()
} }
@ -603,87 +588,69 @@ func (e *entry[K, V]) lookup(key K) (V, bool) {
func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) { func (e *entry[K, V]) lookupWithValue(key K, value V, valEqual equalFunc) (V, bool) {
for e != nil { for e != nil {
oldp := e.value.Load() if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value)))) {
if e.key == key && (valEqual == nil || valEqual(unsafe.Pointer(oldp), abi.NoEscape(unsafe.Pointer(&value)))) { return e.value, true
return *oldp, true
} }
e = e.overflow.Load() e = e.overflow.Load()
} }
return *new(V), false return *new(V), false
} }
// swap replaces a value in the overflow chain if keys compare equal. // swap replaces an entry in the overflow chain if keys compare equal. Returns the new entry chain,
// Returns the old value, and whether or not anything was swapped. // the old value, and whether or not anything was swapped.
// //
// swap must be called under the mutex of the indirect node which e is a child of. // swap must be called under the mutex of the indirect node which e is a child of.
func (head *entry[K, V]) swap(key K, newv V) (V, bool) { func (head *entry[K, V]) swap(key K, new V) (*entry[K, V], V, bool) {
if head.key == key { if head.key == key {
vp := new(V) // Return the new head of the list.
*vp = newv e := newEntryNode(key, new)
oldp := head.value.Swap(vp) if chain := head.overflow.Load(); chain != nil {
return *oldp, true e.overflow.Store(chain)
}
return e, head.value, true
} }
i := &head.overflow i := &head.overflow
e := i.Load() e := i.Load()
for e != nil { for e != nil {
if e.key == key { if e.key == key {
vp := new(V) eNew := newEntryNode(key, new)
*vp = newv eNew.overflow.Store(e.overflow.Load())
oldp := e.value.Swap(vp) i.Store(eNew)
return *oldp, true return head, e.value, true
} }
i = &e.overflow i = &e.overflow
e = e.overflow.Load() e = e.overflow.Load()
} }
var zero V var zero V
return zero, false return head, zero, false
} }
// compareAndSwap replaces a value for a matching key and existing value in the overflow chain. // compareAndSwap replaces an entry in the overflow chain if both the key and value compare
// Returns whether or not anything was swapped. // equal. Returns the new entry chain and whether or not anything was swapped.
// //
// compareAndSwap must be called under the mutex of the indirect node which e is a child of. // compareAndSwap must be called under the mutex of the indirect node which e is a child of.
func (head *entry[K, V]) compareAndSwap(key K, oldv, newv V, valEqual equalFunc) bool { func (head *entry[K, V]) compareAndSwap(key K, old, new V, valEqual equalFunc) (*entry[K, V], bool) {
var vbox *V if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&old))) {
outerLoop: // Return the new head of the list.
for { e := newEntryNode(key, new)
oldvp := head.value.Load() if chain := head.overflow.Load(); chain != nil {
if head.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) { e.overflow.Store(chain)
// Return the new head of the list.
if vbox == nil {
// Delay explicit creation of a new value to hold newv. If we just pass &newv
// to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
vbox = new(V)
*vbox = newv
}
if head.value.CompareAndSwap(oldvp, vbox) {
return true
}
// We need to restart from the head of the overflow list in case, due to a removal, a node
// is moved up the list and we miss it.
continue outerLoop
} }
i := &head.overflow return e, true
e := i.Load()
for e != nil {
oldvp := e.value.Load()
if e.key == key && valEqual(unsafe.Pointer(oldvp), abi.NoEscape(unsafe.Pointer(&oldv))) {
if vbox == nil {
// Delay explicit creation of a new value to hold newv. If we just pass &newv
// to CompareAndSwap, then newv will unconditionally escape, even if the CAS fails.
vbox = new(V)
*vbox = newv
}
if e.value.CompareAndSwap(oldvp, vbox) {
return true
}
continue outerLoop
}
i = &e.overflow
e = e.overflow.Load()
}
return false
} }
i := &head.overflow
e := i.Load()
for e != nil {
if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&old))) {
eNew := newEntryNode(key, new)
eNew.overflow.Store(e.overflow.Load())
i.Store(eNew)
return head, true
}
i = &e.overflow
e = e.overflow.Load()
}
return head, false
} }
// loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new // loadAndDelete deletes an entry in the overflow chain by key. Returns the value for the key, the new
@ -693,14 +660,14 @@ outerLoop:
func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) { func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
if head.key == key { if head.key == key {
// Drop the head of the list. // Drop the head of the list.
return *head.value.Load(), head.overflow.Load(), true return head.value, head.overflow.Load(), true
} }
i := &head.overflow i := &head.overflow
e := i.Load() e := i.Load()
for e != nil { for e != nil {
if e.key == key { if e.key == key {
i.Store(e.overflow.Load()) i.Store(e.overflow.Load())
return *e.value.Load(), head, true return e.value, head, true
} }
i = &e.overflow i = &e.overflow
e = e.overflow.Load() e = e.overflow.Load()
@ -713,14 +680,14 @@ func (head *entry[K, V]) loadAndDelete(key K) (V, *entry[K, V], bool) {
// //
// compareAndDelete must be called under the mutex of the indirect node which e is a child of. // compareAndDelete must be called under the mutex of the indirect node which e is a child of.
func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) { func (head *entry[K, V]) compareAndDelete(key K, value V, valEqual equalFunc) (*entry[K, V], bool) {
if head.key == key && valEqual(unsafe.Pointer(head.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) { if head.key == key && valEqual(unsafe.Pointer(&head.value), abi.NoEscape(unsafe.Pointer(&value))) {
// Drop the head of the list. // Drop the head of the list.
return head.overflow.Load(), true return head.overflow.Load(), true
} }
i := &head.overflow i := &head.overflow
e := i.Load() e := i.Load()
for e != nil { for e != nil {
if e.key == key && valEqual(unsafe.Pointer(e.value.Load()), abi.NoEscape(unsafe.Pointer(&value))) { if e.key == key && valEqual(unsafe.Pointer(&e.value), abi.NoEscape(unsafe.Pointer(&value))) {
i.Store(e.overflow.Load()) i.Store(e.overflow.Load())
return head, true return head, true
} }