This commit is contained in:
Waterlens 2026-05-04 10:11:34 +10:00 committed by GitHub
commit d02fc57a64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -152,11 +152,17 @@ class _Block:
fallthrough: bool = True
# Whether this block can eventually reach the next uop (_JIT_CONTINUE):
hot: bool = False
# Whether this block should be emitted in the hot section. This is separate
# from "hot": some cold fallthrough bridges must stay in the hot layout.
layout_hot: bool | None = None
# Whether this original assembler metadata/tail block should be preserved
# even if it is unreachable.
is_metadata: bool = False
def resolve(self) -> typing.Self:
    """Find the first non-empty block reachable from this one.

    Only fallthrough links are followed: the walk stops at the first block
    that has instructions, has no successor, or does not fall through to
    its successor.
    """
    block = self
    # An empty block that does not fall through never reaches its link, so
    # it must not be skipped over.
    while block.link and not block.instructions and block.fallthrough:
        block = block.link
    return block
@ -208,6 +214,8 @@ class Optimizer:
const_reloc = "<Not supported>"
_frame_pointer_modify: typing.ClassVar[re.Pattern[str]] = _RE_NEVER_MATCH
label_index: int = 0
_cold_start: _Block | None = dataclasses.field(init=False, default=None)
_jump_name = "<Not supported>"
def __post_init__(self) -> None:
# Split the code into a linked list of basic blocks. A basic block is an
@ -339,18 +347,224 @@ def _lookup_label(self, label: str) -> _Block:
def _is_far_target(self, label: str) -> bool:
    """Return True when *label* does not start with ``self.label_prefix``."""
    is_local = label.startswith(self.label_prefix)
    return not is_local
def _continuation(self) -> _Block:
    """Look up the block labelled ``<label_prefix>_JIT_CONTINUE``."""
    continue_label = f"{self.label_prefix}_JIT_CONTINUE"
    return self._lookup_label(continue_label)
def _cold_start_block(self) -> _Block:
    """Return the ``<symbol_prefix>_JIT_COLD_START`` block, caching it.

    On the first call the block is fetched through ``_lookup_label``, its
    label definition line is appended to ``noninstructions``, and it is
    pinned to the cold layout. Later calls return the cached block.
    """
    if self._cold_start is not None:
        return self._cold_start
    label = f"{self.symbol_prefix}_JIT_COLD_START"
    cold_start = self._lookup_label(label)
    self._cold_start = cold_start
    cold_start.noninstructions.append(f"{label}:")
    cold_start.layout_hot = False
    return cold_start
def _make_label(self) -> str:
    """Return a fresh ``<label_prefix>_JIT_LABEL_<n>`` name and bump the counter."""
    index = self.label_index
    self.label_index = index + 1
    return f"{self.label_prefix}_JIT_LABEL_{index}"
def _ensure_label(self, block: _Block) -> str:
    """Return *block*'s label, generating and registering one if it has none.

    A newly generated label is recorded in ``self._labels`` and its
    definition line is placed first in the block's ``noninstructions``.
    """
    if block.label is not None:
        return block.label
    label = self._make_label()
    block.label = label
    self._labels[label] = block
    block.noninstructions.insert(0, f"{label}:")
    return label
def _make_jump(self, target: _Block, *, hot: bool) -> _Block:
    """Create a one-instruction block that jumps to *target*.

    *target* is given a label if it lacks one. The new block holds a single
    ``InstructionKind.JUMP`` instruction using ``self._jump_name``, does not
    fall through, carries the requested *hot* flag, and is always pinned to
    the hot layout. Its ``link`` is left for the caller to set.
    """
    label = self._ensure_label(target)
    mnemonic = self._jump_name
    jump = Instruction(
        InstructionKind.JUMP,
        mnemonic,
        f"\t{mnemonic} {label}",
        None,
        label,
    )
    return _Block(
        instructions=[jump],
        target=target,
        fallthrough=False,
        hot=hot,
        layout_hot=True,
    )
def _effective_layout_hot(self, block: _Block) -> bool:
    """Decide whether *block* belongs to the hot layout section.

    Precedence: an explicit ``layout_hot`` on the resolved block, then an
    explicit ``layout_hot`` on *block* itself, then the resolved block's
    ``hot`` flag.
    """
    resolved = block.resolve()
    for explicit in (resolved.layout_hot, block.layout_hot):
        if explicit is not None:
            return explicit
    return resolved.hot
def _same_layout_section(self, left: _Block, right: _Block) -> bool:
    """Return True when both blocks land in the same (hot or cold) section."""
    left_hot = self._effective_layout_hot(left)
    right_hot = self._effective_layout_hot(right)
    return left_hot == right_hot
def _can_short_branch_to_layout(
    self, inst: Instruction, source: _Block, target: _Block
) -> bool:
    """Return whether *inst* in *source* may branch to *target* after layout.

    Only ``SHORT_BRANCH`` instructions are restricted: they are allowed
    only when source and target share a layout section. Every other kind
    is always allowed.
    """
    if inst.kind == InstructionKind.SHORT_BRANCH:
        return self._same_layout_section(source, target)
    return True
def _insert_fallthrough_bridge(self, block: _Block, target: _Block) -> _Block:
    """Splice a jump to *target* into the list directly after *block*.

    The jump block inherits *target*'s ``hot`` flag, takes over *block*'s
    old successor as its own ``link``, and is returned.
    """
    bridge = self._make_jump(target, hot=target.hot)
    # Right-hand side is evaluated first, so the bridge gets the old successor.
    bridge.link, block.link = block.link, bridge
    return bridge
def _ensure_hot_fallthrough(self, block: _Block) -> None:
    """Make *block*'s fallthrough edge safe for hot/cold partitioning.

    Nothing is done for the continuation block, the cold-start block, a
    block without a successor, a block that is not laid out hot, or a block
    that does not fall through. Otherwise the block's final branch is
    inverted and/or an explicit jump block is linked in after it, so that
    control flow no longer relies on a successor that will be laid out in
    the other section.
    """
    if block is self._continuation() or block is self._cold_start:
        return
    successor = block.link
    if successor is None:
        return
    if not self._effective_layout_hot(block):
        return
    if not block.fallthrough:
        return
    successor_is_hot = self._effective_layout_hot(successor)
    target = block.target
    last = block.instructions[-1] if block.instructions else None

    def splice_bridge(
        inverted: Instruction, jump_target: _Block, *, hot: bool
    ) -> None:
        # Replace the final branch with its inverse (now aimed at the old
        # successor) and link a jump to the old branch target in between.
        bridge = self._make_jump(jump_target, hot=hot)
        bridge.link = successor
        block.instructions[-1] = inverted
        block.target = successor
        block.link = bridge

    # Keep AArch64 short branches in the hot layout:
    #     tbz x0, #0, .Lcold   ->   tbnz x0, #0, .Lhot
    #   .Lhot:                      b .Lcold
    #                             .Lhot:
    if (
        successor_is_hot
        and last is not None
        and last.kind == InstructionKind.SHORT_BRANCH
        and target is not None
        and not self._effective_layout_hot(target)
    ):
        successor_label = self._ensure_label(successor)
        inverted = self._invert_branch(last, successor_label)
        assert inverted is not None
        splice_bridge(inverted, target, hot=target.hot)
        return
    if successor_is_hot:
        return
    # Make a hot-to-cold fallthrough explicit, preferably by inverting:
    #     b.eq .Lhot   ->   b.ne .Lcold
    #   .Lcold:           .Lhot:
    if (
        last is not None
        and last.is_branch()
        and target is not None
        and self._effective_layout_hot(target)
    ):
        successor_label = self._ensure_label(successor)
        if self._can_short_branch_to_layout(last, block, successor):
            inverted = self._invert_branch(last, successor_label)
            if inverted is not None:
                splice_bridge(inverted, target, hot=True)
                return
    # If no inversion is possible, preserve the old fallthrough with:
    #     b .Lcold
    self._insert_fallthrough_bridge(block, successor)
def _layout_units(self) -> list[tuple[bool, list[_Block]]]:
    """Group layout blocks into units that must stay contiguous.

    A unit is a run of blocks ending at the first block that has
    instructions or does not fall through. The continuation and cold-start
    blocks end the current unit and are never part of one. Every block in
    a unit is stamped with the unit's section, which is the effective
    layout of the unit's first block.

    Returns ``(layout_hot, blocks)`` pairs in list order.
    """
    continuation = self._continuation()
    cold_start = self._cold_start_block()
    units: list[tuple[bool, list[_Block]]] = []
    pending: list[_Block] = []

    def flush() -> None:
        if not pending:
            return
        group = list(pending)
        pending.clear()
        is_hot = self._effective_layout_hot(group[0])
        for member in group:
            member.layout_hot = is_hot
        units.append((is_hot, group))

    for block in self._layout_blocks():
        if block is continuation or block is cold_start:
            flush()
            continue
        pending.append(block)
        if block.instructions or not block.fallthrough:
            flush()
    flush()
    return units
def _metadata_blocks(self) -> list[_Block]:
    """Collect, in list order, every block flagged ``is_metadata``."""
    return [
        candidate
        for candidate in self._blocks()
        if candidate.is_metadata
    ]
def _relink_blocks(self, blocks: list[_Block]) -> None:
    """Rewrite ``link`` so *blocks* form a chain in the given order.

    The last block's ``link`` is set to None. An empty list is a no-op.
    """
    if not blocks:
        return
    previous = blocks[0]
    for successor in blocks[1:]:
        previous.link = successor
        previous = successor
    previous.link = None
def _partition_hot_cold_blocks(self) -> None:
    """Reorder the block list into hot code, continuation, cold code, metadata.

    The resulting order is: every hot unit, the continuation block, the
    cold-start block, every cold unit, then the metadata blocks.
    """
    # The entry point must remain in the hot layout, even when it can't
    # reach _JIT_CONTINUE. The stencil parser expects _JIT_ENTRY at code
    # offset 0.
    entry_label = f"{self.symbol_prefix}_JIT_ENTRY"
    for candidate in self._layout_blocks():
        if candidate.label == entry_label:
            candidate.layout_hot = True
    # Snapshot the blocks first: fixing fallthroughs can link in new blocks.
    snapshot = list(self._layout_blocks())
    for candidate in snapshot:
        self._ensure_hot_fallthrough(candidate)

    continuation = self._continuation()
    continuation.layout_hot = True
    continuation.fallthrough = False
    cold_start = self._cold_start_block()
    cold_start.layout_hot = False
    cold_start.fallthrough = True

    hot_blocks: list[_Block] = []
    cold_blocks: list[_Block] = []
    for layout_hot, unit in self._layout_units():
        section = hot_blocks if layout_hot else cold_blocks
        section.extend(unit)

    new_order = [*hot_blocks, continuation, cold_start, *cold_blocks]
    new_order.extend(self._metadata_blocks())
    self._relink_blocks(new_order)
def _blocks(self) -> typing.Generator[_Block, None, None]:
    """Yield every block in the linked list, starting at ``self._root``.

    Each ``link`` is read only after the consumer resumes the generator,
    so relinking done between steps is observed.
    """
    cursor: _Block | None = self._root
    while cursor:
        yield cursor
        cursor = cursor.link
def _layout_blocks(self) -> typing.Generator[_Block, None, None]:
    """Yield the blocks of the list that are not flagged ``is_metadata``."""
    for candidate in self._blocks():
        if candidate.is_metadata:
            continue
        yield candidate
def _body(self) -> str:
lines = ["#" + line for line in self.text.splitlines()]
hot = True
hot: bool | None = True
for block in self._blocks():
if hot != block.hot:
hot = block.hot
layout_hot = block.layout_hot
if layout_hot is None:
layout_hot = block.hot
if hot != layout_hot:
hot = layout_hot
# Make it easy to tell at a glance where cold code is:
lines.append(f"# JIT: {'HOT' if hot else 'COLD'} ".ljust(80, "#"))
lines.extend(block.noninstructions)
@ -378,12 +592,17 @@ def _insert_continue_label(self) -> None:
continuation = self._lookup_label(f"{self.label_prefix}_JIT_CONTINUE")
assert continuation.label
continuation.noninstructions.append(f"{continuation.label}:")
continuation.layout_hot = True
tail = end.link
while tail:
tail.is_metadata = True
tail = tail.link
end.link, continuation.link = continuation, end.link
def _mark_hot_blocks(self) -> None:
# Start with the last block, and perform a DFS to find all blocks that
# can eventually reach it:
todo = list(self._blocks())[-1:]
# Start with the continuation block, and perform a DFS to find all
# blocks that can eventually reach it:
todo = [self._lookup_label(f"{self.label_prefix}_JIT_CONTINUE")]
while todo:
block = todo.pop()
block.hot = True
@ -413,11 +632,14 @@ def _invert_hot_branches(self) -> None:
and len(jump.instructions) == 1
and list(self._predecessors(jump)) == [branch]
):
jump.layout_hot = True
assert jump.target.label
assert branch.target.label
inverted = self._invert_branch(
branch.instructions[-1], jump.target.label
)
inst = branch.instructions[-1]
if inst.kind == InstructionKind.SHORT_BRANCH:
inverted = None
else:
inverted = self._invert_branch(inst, jump.target.label)
# Check to see if the branch can even be inverted:
if inverted is None:
continue
@ -427,10 +649,12 @@ def _invert_hot_branches(self) -> None:
)
branch.target, jump.target = jump.target, branch.target
jump.hot = True
jump.layout_hot = True
def _remove_redundant_jumps(self) -> None:
# Zero-length jumps can be introduced by _insert_continue_label and
# _invert_hot_branches:
continuation = self._continuation()
for block in self._blocks():
target = block.target
if target is None:
@ -441,7 +665,14 @@ def _remove_redundant_jumps(self) -> None:
# FOO:
# After:
# FOO:
if block.link and target is block.link.resolve():
if (
block.link
and target is block.link.resolve()
and (
self._same_layout_section(block, target)
or target is continuation
)
):
block.target = None
block.fallthrough = True
block.instructions.pop()
@ -459,12 +690,15 @@ def _remove_redundant_jumps(self) -> None:
):
assert target.target is not None
assert target.target.label is not None
inst = block.instructions[-1]
if block.instructions[
-1
].kind == InstructionKind.SHORT_BRANCH and self._is_far_target(
target.target.label
):
continue
if not self._can_short_branch_to_layout(inst, block, target.target):
continue
block.target = target.target
block.instructions[-1] = block.instructions[-1].update_target(
target.target.label
@ -488,20 +722,38 @@ def _find_live_blocks(self) -> set[_Block]:
def _remove_unreachable(self) -> None:
    """Unlink every block that is neither live nor explicitly preserved.

    A block is dropped from the linked list only when it is not in the live
    set, is not the continuation or cold-start block, is not metadata, and
    has a predecessor to unlink it from (so the root block always stays).
    """
    live = self._find_live_blocks()
    continuation = self._continuation()
    cont_or_cold_blocks = {continuation}
    if self._cold_start is not None:
        cont_or_cold_blocks.add(self._cold_start)
    # Keep only the original assembler tail. Cold code after _JIT_CONTINUE
    # is ordinary code and can be removed when unreachable.
    prev: _Block | None = None
    block = self._root
    # The whole list is walked (not just the part before _JIT_CONTINUE), so
    # track the sentinels explicitly and check them once the walk is done.
    seen_continuation = False
    seen_cold_start = self._cold_start is None
    while block is not None:
        if block is continuation:
            seen_continuation = True
        if block is self._cold_start:
            seen_cold_start = True
        # Read the successor before any relinking; it is None at the tail.
        successor = block.link
        if (
            block not in live
            and block not in cont_or_cold_blocks
            and not block.is_metadata
            and prev is not None
        ):
            prev.link = successor
        else:
            prev = block
        block = successor
    if prev is not None:
        # The last block that was kept must terminate the list.
        assert prev.link is block
    assert seen_continuation
    assert seen_cold_start
def _fixup_external_labels(self) -> None:
if self._supports_external_relocations:
@ -544,8 +796,10 @@ def run(self) -> None:
self._mark_hot_blocks()
# Removing branches can expose opportunities for more branch removal.
# Repeat a few times. 2 would probably do, but it's fast enough with 4.
for _ in range(4):
for iter in range(4):
self._invert_hot_branches()
if iter == 0:
self._partition_hot_cold_blocks()
self._remove_redundant_jumps()
self._remove_unreachable()
self._fixup_external_labels()
@ -559,6 +813,7 @@ class OptimizerAArch64(Optimizer): # pylint: disable = too-few-public-methods
_branches = _AARCH64_BRANCHES
_short_branches = _AARCH64_SHORT_BRANCHES
_jump_name = "b"
# Mach-O does not support the 19 bit branch locations needed for branch reordering
_supports_external_relocations = False
_branch_patterns = [name.replace(".", r"\.") for name in _AARCH64_BRANCHES]
@ -776,6 +1031,7 @@ class OptimizerX86(Optimizer): # pylint: disable = too-few-public-methods
_branches = _X86_BRANCHES
_short_branches = {}
_jump_name = "jmp"
_re_branch = re.compile(
rf"\s*(?P<instruction>{'|'.join(_X86_BRANCHES)})\s+(?P<target>[\w.]+)"
)