mirror of
https://github.com/python/cpython.git
synced 2025-12-08 06:10:17 +00:00
gh-98896: resource_tracker: use json&base64 to allow arbitrary shared memory names (GH-138473)
This commit is contained in:
parent
70748bdbea
commit
c6f3dd6a50
3 changed files with 97 additions and 8 deletions
|
|
@ -15,6 +15,7 @@
|
||||||
# this resource tracker process, "killall python" would probably leave unlinked
|
# this resource tracker process, "killall python" would probably leave unlinked
|
||||||
# resources.
|
# resources.
|
||||||
|
|
||||||
|
import base64
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -22,6 +23,8 @@
|
||||||
import warnings
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
from . import spawn
|
from . import spawn
|
||||||
from . import util
|
from . import util
|
||||||
|
|
||||||
|
|
@ -196,6 +199,17 @@ def _launch(self):
|
||||||
finally:
|
finally:
|
||||||
os.close(r)
|
os.close(r)
|
||||||
|
|
||||||
|
def _make_probe_message(self):
|
||||||
|
"""Return a JSON-encoded probe message."""
|
||||||
|
return (
|
||||||
|
json.dumps(
|
||||||
|
{"cmd": "PROBE", "rtype": "noop"},
|
||||||
|
ensure_ascii=True,
|
||||||
|
separators=(",", ":"),
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
).encode("ascii")
|
||||||
|
|
||||||
def _ensure_running_and_write(self, msg=None):
|
def _ensure_running_and_write(self, msg=None):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._lock._recursion_count() > 1:
|
if self._lock._recursion_count() > 1:
|
||||||
|
|
@ -207,7 +221,7 @@ def _ensure_running_and_write(self, msg=None):
|
||||||
if self._fd is not None:
|
if self._fd is not None:
|
||||||
# resource tracker was launched before, is it still running?
|
# resource tracker was launched before, is it still running?
|
||||||
if msg is None:
|
if msg is None:
|
||||||
to_send = b'PROBE:0:noop\n'
|
to_send = self._make_probe_message()
|
||||||
else:
|
else:
|
||||||
to_send = msg
|
to_send = msg
|
||||||
try:
|
try:
|
||||||
|
|
@ -234,7 +248,7 @@ def _check_alive(self):
|
||||||
try:
|
try:
|
||||||
# We cannot use send here as it calls ensure_running, creating
|
# We cannot use send here as it calls ensure_running, creating
|
||||||
# a cycle.
|
# a cycle.
|
||||||
os.write(self._fd, b'PROBE:0:noop\n')
|
os.write(self._fd, self._make_probe_message())
|
||||||
except OSError:
|
except OSError:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
|
|
@ -253,11 +267,25 @@ def _write(self, msg):
|
||||||
assert nbytes == len(msg), f"{nbytes=} != {len(msg)=}"
|
assert nbytes == len(msg), f"{nbytes=} != {len(msg)=}"
|
||||||
|
|
||||||
def _send(self, cmd, name, rtype):
|
def _send(self, cmd, name, rtype):
|
||||||
msg = f"{cmd}:{name}:{rtype}\n".encode("ascii")
|
# POSIX guarantees that writes to a pipe of less than PIPE_BUF (at least 512; 4096 on Linux)
|
||||||
if len(msg) > 512:
|
# bytes are atomic. Therefore, we want the message to be shorter than 512 bytes.
|
||||||
# posix guarantees that writes to a pipe of less than PIPE_BUF
|
# POSIX shm_open() and sem_open() require the name, including its leading slash,
|
||||||
# bytes are atomic, and that PIPE_BUF >= 512
|
# to be at most NAME_MAX bytes (255 on Linux)
|
||||||
raise ValueError('msg too long')
|
# With json.dumps(..., ensure_ascii=True) every non-ASCII byte becomes a 6-char
|
||||||
|
# escape like \uDC80.
|
||||||
|
# As we want the overall message to be kept atomic and therefore smaller than 512,
|
||||||
|
# we encode the raw name bytes with URL-safe Base64 - so a 255-byte name
|
||||||
|
# will not exceed 340 bytes.
|
||||||
|
b = name.encode('utf-8', 'surrogateescape')
|
||||||
|
if len(b) > 255:
|
||||||
|
raise ValueError('shared memory name too long (max 255 bytes)')
|
||||||
|
b64 = base64.urlsafe_b64encode(b).decode('ascii')
|
||||||
|
|
||||||
|
payload = {"cmd": cmd, "rtype": rtype, "base64_name": b64}
|
||||||
|
msg = (json.dumps(payload, ensure_ascii=True, separators=(",", ":")) + "\n").encode("ascii")
|
||||||
|
|
||||||
|
# The entire JSON message is guaranteed to be at most 512 bytes (the POSIX minimum for PIPE_BUF) by construction.
|
||||||
|
assert len(msg) <= 512, f"internal error: message too long ({len(msg)} bytes)"
|
||||||
|
|
||||||
self._ensure_running_and_write(msg)
|
self._ensure_running_and_write(msg)
|
||||||
|
|
||||||
|
|
@ -290,7 +318,23 @@ def main(fd):
|
||||||
with open(fd, 'rb') as f:
|
with open(fd, 'rb') as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
try:
|
try:
|
||||||
cmd, name, rtype = line.strip().decode('ascii').split(':')
|
try:
|
||||||
|
obj = json.loads(line.decode('ascii'))
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError("malformed resource_tracker message: %r" % (line,)) from e
|
||||||
|
|
||||||
|
cmd = obj["cmd"]
|
||||||
|
rtype = obj["rtype"]
|
||||||
|
b64 = obj.get("base64_name", "")
|
||||||
|
|
||||||
|
if not isinstance(cmd, str) or not isinstance(rtype, str) or not isinstance(b64, str):
|
||||||
|
raise ValueError("malformed resource_tracker fields: %r" % (obj,))
|
||||||
|
|
||||||
|
try:
|
||||||
|
name = base64.urlsafe_b64decode(b64).decode('utf-8', 'surrogateescape')
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError("malformed resource_tracker base64_name: %r" % (b64,)) from e
|
||||||
|
|
||||||
cleanup_func = _CLEANUP_FUNCS.get(rtype, None)
|
cleanup_func = _CLEANUP_FUNCS.get(rtype, None)
|
||||||
if cleanup_func is None:
|
if cleanup_func is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|
|
||||||
|
|
@ -7364,3 +7364,46 @@ def test_forkpty(self):
|
||||||
res = assert_python_failure("-c", code, PYTHONWARNINGS='error')
|
res = assert_python_failure("-c", code, PYTHONWARNINGS='error')
|
||||||
self.assertIn(b'DeprecationWarning', res.err)
|
self.assertIn(b'DeprecationWarning', res.err)
|
||||||
self.assertIn(b'is multi-threaded, use of forkpty() may lead to deadlocks in the child', res.err)
|
self.assertIn(b'is multi-threaded, use of forkpty() may lead to deadlocks in the child', res.err)
|
||||||
|
|
||||||
|
@unittest.skipUnless(HAS_SHMEM, "requires multiprocessing.shared_memory")
|
||||||
|
class TestSharedMemoryNames(unittest.TestCase):
|
||||||
|
def test_that_shared_memory_name_with_colons_has_no_resource_tracker_errors(self):
|
||||||
|
# Test script that creates and cleans up shared memory with colon in name
|
||||||
|
test_script = textwrap.dedent("""
|
||||||
|
import sys
|
||||||
|
from multiprocessing import shared_memory
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Test various patterns of colons in names
|
||||||
|
test_names = [
|
||||||
|
"a:b",
|
||||||
|
"a:b:c",
|
||||||
|
"test:name:with:many:colons",
|
||||||
|
":starts:with:colon",
|
||||||
|
"ends:with:colon:",
|
||||||
|
"::double::colons::",
|
||||||
|
"name\\nwithnewline",
|
||||||
|
"name-with-trailing-newline\\n",
|
||||||
|
"\\nname-starts-with-newline",
|
||||||
|
"colons:and\\nnewlines:mix",
|
||||||
|
"multi\\nline\\nname",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in test_names:
|
||||||
|
try:
|
||||||
|
shm = shared_memory.SharedMemory(create=True, size=100, name=name)
|
||||||
|
shm.buf[:5] = b'hello' # Write something to the shared memory
|
||||||
|
shm.close()
|
||||||
|
shm.unlink()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error with name '{name}': {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print("SUCCESS")
|
||||||
|
""")
|
||||||
|
|
||||||
|
rc, out, err = assert_python_ok("-c", test_script)
|
||||||
|
self.assertIn(b"SUCCESS", out)
|
||||||
|
self.assertNotIn(b"traceback", err.lower(), err)
|
||||||
|
self.assertNotIn(b"resource_tracker.py", err, err)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
Fix a failure in multiprocessing resource_tracker when SharedMemory names contain colons.
|
||||||
|
Patch by Rani Pinchuk.
|
||||||
Loading…
Add table
Add a link
Reference in a new issue