Skip to content

gh-98552: Fix preloading '__main__' with forkserver being broken for a #99515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions Lib/multiprocessing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,20 @@ def set_executable(self, executable):
from .spawn import set_executable
set_executable(executable)

def set_forkserver_preload(self, module_names):
def set_forkserver_preload(self, module_names, raise_exceptions=False):
'''Set list of module names to try to load in forkserver process.
This is really just a hint.

If this method is not called, the default list of modules_names is
['__main__']. In most scenarios, callers will want to specify '__main__'
as the first entry in modules_names when calling this method.

By default, any exceptions from importing the specified module names
are suppressed. Set raise_exceptions = True to not suppress. If an
exception is raised and not suppressed, the forkserver will exit and
new process creation will fail.
'''
from .forkserver import set_forkserver_preload
set_forkserver_preload(module_names)
set_forkserver_preload(module_names, raise_exceptions)

def get_context(self, method=None):
if method is None:
Expand Down
55 changes: 41 additions & 14 deletions Lib/multiprocessing/forkserver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import base64
import errno
import json
import os
import selectors
import signal
Expand Down Expand Up @@ -38,6 +40,7 @@ def __init__(self):
self._inherited_fds = None
self._lock = threading.Lock()
self._preload_modules = ['__main__']
self._preload_modules_raise_exceptions = False

def _stop(self):
# Method used by unit tests to stop the server
Expand All @@ -59,11 +62,22 @@ def _stop_unlocked(self):
os.unlink(self._forkserver_address)
self._forkserver_address = None

def set_forkserver_preload(self, modules_names):
'''Set list of module names to try to load in forkserver process.'''
def set_forkserver_preload(self, modules_names, raise_exceptions=False):
    '''Set list of module names to try to load in forkserver process.

    If this method is not called, the default list of modules_names is
    ['__main__'].  In most scenarios, callers will want to specify
    '__main__' as the first entry in modules_names when calling this
    method.

    By default, an ImportError raised while importing one of the
    specified module names is suppressed.  Set raise_exceptions=True to
    not suppress it.  If the exception is raised and not suppressed, the
    forkserver will exit and new process creation will fail.

    Raises TypeError if any entry of modules_names is not a str; in that
    case the previously stored preload settings are left unchanged.
    '''
    # Validate the incoming argument, not the list that is already stored:
    # checking self._preload_modules would accept any new value.
    if not all(type(mod) is str for mod in modules_names):
        raise TypeError('module_names must be a list of strings')
    self._preload_modules = modules_names
    self._preload_modules_raise_exceptions = raise_exceptions

def get_inherited_fds(self):
'''Return list of fds inherited from parent process.
Expand Down Expand Up @@ -124,14 +138,19 @@ def ensure_running(self):
self._forkserver_pid = None

cmd = ('from multiprocessing.forkserver import main; ' +
'main(%d, %d, %r, **%r)')
'main(%d, %d, %r, %r, %r)')

spawn_data = spawn.get_preparation_data('ignore')

if self._preload_modules:
desired_keys = {'main_path', 'sys_path'}
data = spawn.get_preparation_data('ignore')
data = {x: y for x, y in data.items() if x in desired_keys}
else:
data = {}
#The authkey cannot be serialized. so clear the value from get_preparation_data
spawn_data.pop('authkey',None)

#The forkserver itself uses the fork start_method, so clear the value from get_preparation_data
spawn_data.pop('start_method',None)

spawn_data_json = json.dumps(spawn_data)
prepare_data_base64_encoded = base64.b64encode(
bytes(spawn_data_json,'utf-8')).decode()

with socket.socket(socket.AF_UNIX) as listener:
address = connection.arbitrary_address('AF_UNIX')
Expand All @@ -146,7 +165,7 @@ def ensure_running(self):
try:
fds_to_pass = [listener.fileno(), alive_r]
cmd %= (listener.fileno(), alive_r, self._preload_modules,
data)
self._preload_modules_raise_exceptions, prepare_data_base64_encoded)
exe = spawn.get_executable()
args = [exe] + util._args_from_interpreter_flags()
args += ['-c', cmd]
Expand All @@ -164,20 +183,25 @@ def ensure_running(self):
#
#

def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
def main(listener_fd, alive_r, preload, raise_import_error, prepare_data_base64_encoded):
'''Run forkserver.'''
if preload:
if '__main__' in preload and main_path is not None:
if prepare_data_base64_encoded is not None:
prepare_data = json.loads(base64.b64decode(prepare_data_base64_encoded).decode('utf-8'))
if '__main__' not in preload:
prepare_data.pop('init_main_from_path',None)
prepare_data.pop('init_main_from_name',None)
process.current_process()._inheriting = True
try:
spawn.import_main_path(main_path)
spawn.prepare(prepare_data)
finally:
del process.current_process()._inheriting
for modname in preload:
try:
__import__(modname)
except ImportError:
pass
if raise_import_error:
raise

util._close_stdin()

Expand Down Expand Up @@ -262,6 +286,9 @@ def sigchld_handler(*_unused):
len(fds)))
child_r, child_w, *fds = fds
s.close()
#Failure to flush these before fork can leave data in the buffers
#for unsuspecting children
util._flush_std_streams()
pid = os.fork()
if pid == 0:
# Child
Expand Down
14 changes: 11 additions & 3 deletions Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5274,11 +5274,19 @@ def test_preload_resources(self):
rc, out, err = test.support.script_helper.assert_python_ok(name)
out = out.decode()
err = err.decode()
if out.rstrip() != 'ok' or err != '':
print(out)
print(err)
expected = "mp_preload\nmp_preload\nmp_preload_import\nf\nf\nf"
if out.rstrip() != expected or err != '':
print("expected out: " + expected)
print("actual out : " + out)
print("err : " + err)
self.fail("failed spawning forkserver or grandchild")

def test_preload_exception(self):
    '''Run mp_preload_exception.py with exception raising off and on.

    The helper script raises RuntimeError (and so exits unsuccessfully)
    when the observed behaviour does not match the requested
    raise_exceptions mode; assert_python_ok() reports that as a failure.
    '''
    if multiprocessing.get_start_method() != 'forkserver':
        self.skipTest("test only relevant for 'forkserver' method")
    name = os.path.join(os.path.dirname(__file__),
                        'mp_preload_exception.py')
    for raise_exception in (0, 1):
        # subTest reports which mode failed and still runs the other one.
        with self.subTest(raise_exception=raise_exception):
            test.support.script_helper.assert_python_ok(
                name, str(raise_exception))

@unittest.skipIf(sys.platform == "win32",
"test semantics don't make sense on Windows")
Expand Down
35 changes: 25 additions & 10 deletions Lib/test/mp_preload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,32 @@

multiprocessing.Lock()


#
# This test verifies that preload is behaving as expected. By preloading
# both __main__ and mp_preload_import, both this module and mp_preload_import
# should be loaded in the forkserver process when it serves new processes.
# This means that each new process and call to f() will not cause additional
# module loading.
#
# The expected output is then:
# mp_preload
# mp_preload
# mp_preload_import
# f
# f
# f
#
# Any deviation from this means something is broken.
#
def f():
print("ok")

import test.mp_preload_import
print('f')

print("mp_preload")
if __name__ == "__main__":
ctx = multiprocessing.get_context("forkserver")
modname = "test.mp_preload"
# Make sure it's importable
__import__(modname)
ctx.set_forkserver_preload([modname])
proc = ctx.Process(target=f)
proc.start()
proc.join()
ctx.set_forkserver_preload(["__main__","test.mp_preload_import"], True)
for i in range(3):
proc = ctx.Process(target=f)
proc.start()
proc.join()
26 changes: 26 additions & 0 deletions Lib/test/mp_preload_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import multiprocessing
import sys

#
# This test verifies that preload on a nonexistent module raises an exception
# that eventually leads to any new process start failing, when we specify that
# as the desired behavior.
#

def f():
print('f')

if __name__ == "__main__":
raise_exceptions = int(sys.argv[1])!=0
ctx = multiprocessing.get_context("forkserver")
ctx.set_forkserver_preload(["__main__","test.mp_preload_import_does_not_exist"], raise_exceptions)
proc = ctx.Process(target=f)
exception_thrown = False
try:
proc.start()
proc.join()
except Exception:
exception_thrown=True
if exception_thrown != raise_exceptions:
raise RuntimeError('Difference between exception_thrown and raise_exceptions')
print('done')
1 change: 1 addition & 0 deletions Lib/test/mp_preload_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
print('mp_preload_import')
1 change: 1 addition & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,7 @@ Trent Nelson
Andrew Nester
Osvaldo Santana Neto
Chad Netzer
Nick Neumann
Max Neunhöffer
Anthon van der Neut
George Neville-Neil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
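Fix preloading ``__main__`` with forkserver, add a ``raise_exceptions``
option to ``set_forkserver_preload()``, and flush the standard streams
before the forkserver forks a new process.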