diff --git a/Lib/concurrent/futures/process.py b/Lib/concurrent/futures/process.py index aaa5151e017c0f..84b1b8bea929ad 100644 --- a/Lib/concurrent/futures/process.py +++ b/Lib/concurrent/futures/process.py @@ -53,6 +53,8 @@ import multiprocessing as mp from multiprocessing.connection import wait from multiprocessing.queues import Queue +from multiprocessing import context + import threading import weakref from functools import partial @@ -100,12 +102,20 @@ def _python_exit(): for t, _ in items: t.join() + # Controls how many more calls than processes will be queued in the call queue. # A smaller number will mean that processes spend more time idle waiting for # work while a larger number will make Future.cancel() succeed less frequently # (Futures in the call queue cannot be cancelled). EXTRA_QUEUED_CALLS = 1 +##### +_ForkingPickler = context.reduction.ForkingPickler +PICKLE_NONE = _ForkingPickler.dumps(None) +WORK_ID_SIZE = 8 +WORK_ID_ENC = "little" +SENTINEL_MSG = b'\x00' + # Hack to embed stringification of remote traceback in local traceback @@ -149,25 +159,6 @@ def __init__(self, work_id, fn, args, kwargs): self.kwargs = kwargs -class _SafeQueue(Queue): - """Safe Queue set exception to the future object linked to a job""" - def __init__(self, max_size=0, *, ctx, pending_work_items): - self.pending_work_items = pending_work_items - super().__init__(max_size, ctx=ctx) - - def _on_queue_feeder_error(self, e, obj): - if isinstance(obj, _CallItem): - tb = traceback.format_exception(type(e), e, e.__traceback__) - e.__cause__ = _RemoteTraceback('\n"""\n{}"""'.format(''.join(tb))) - work_item = self.pending_work_items.pop(obj.work_id, None) - # work_item can be None if another process terminated. In this case, - # the queue_manager_thread fails all work_items with BrokenProcessPool - if work_item is not None: - work_item.future.set_exception(e) - else: - super()._on_queue_feeder_error(e, obj) - - def _get_chunks(*iterables, chunksize): """ Iterates over zip()ed iterables in chunks. """ it = zip(*iterables) @@ -192,11 +183,14 @@ def _process_chunk(fn, chunk): def _sendback_result(result_queue, work_id, result=None, exception=None): """Safely send back the given result or exception""" try: - result_queue.put(_ResultItem(work_id, result=result, - exception=exception)) + serialize_res = _ForkingPickler.dumps( + _ResultItem(work_id, result=result, exception=exception)) except BaseException as e: - exc = _ExceptionWithTraceback(e, e.__traceback__) - result_queue.put(_ResultItem(work_id, exception=exc)) + serialize_res = _ForkingPickler.dumps(_ResultItem( + work_id, exception=_ExceptionWithTraceback(e, e.__traceback__) + )) + result_queue._put_bytes(work_id.to_bytes(WORK_ID_SIZE, WORK_ID_ENC) + + serialize_res) def _process_worker(call_queue, result_queue, initializer, initargs): @@ -221,18 +215,21 @@ def _process_worker(call_queue, result_queue, initializer, initargs): # mark the pool broken return while True: - call_item = call_queue.get(block=True) - if call_item is None: + serialized_item = call_queue._get_bytes(block=True) + if serialized_item == SENTINEL_MSG: # Wake up queue management thread - result_queue.put(os.getpid()) + result_queue._put_bytes( + os.getpid().to_bytes(WORK_ID_SIZE, WORK_ID_ENC)) return + work_id = int.from_bytes(serialized_item[:WORK_ID_SIZE], WORK_ID_ENC) + call_item = None try: + call_item = _ForkingPickler.loads(serialized_item[WORK_ID_SIZE:]) r = call_item.fn(*call_item.args, **call_item.kwargs) + _sendback_result(result_queue, work_id, result=r) except BaseException as e: exc = _ExceptionWithTraceback(e, e.__traceback__) - _sendback_result(result_queue, call_item.work_id, exception=exc) - else: - _sendback_result(result_queue, call_item.work_id, result=r) + _sendback_result(result_queue, work_id, exception=exc) # Liberate the resource as soon as possible, to avoid holding onto # open files or shared memory that is not needed anymore @@ -267,14 +264,27 @@ def _add_call_item_to_queue(pending_work_items, work_item = pending_work_items[work_id] if work_item.future.set_running_or_notify_cancel(): - call_queue.put(_CallItem(work_id, - work_item.fn, - work_item.args, - work_item.kwargs), - block=True) - else: - del pending_work_items[work_id] - continue + call_item = _CallItem(work_id, work_item.fn, work_item.args, + work_item.kwargs) + try: + msg = _ForkingPickler.dumps(call_item) + except BaseException as e: + tb = traceback.format_exception( + type(e), e, e.__traceback__) + e.__cause__ = _RemoteTraceback( + '\n"""\n{}"""'.format(''.join(tb))) + # work_item can be None if a process terminated and the + # executor is broken + if work_item is not None: + work_item.future.set_exception(e) + del work_item + + del pending_work_items[work_id] + continue + call_queue._put_bytes( + work_id.to_bytes(WORK_ID_SIZE, WORK_ID_ENC) + msg, + block=True) + def _queue_management_worker(executor_reference, @@ -321,7 +331,7 @@ def shutdown_worker(): while n_sentinels_sent < n_children_to_stop and n_children_alive > 0: for i in range(n_children_to_stop - n_sentinels_sent): try: - call_queue.put_nowait(None) + call_queue._put_bytes(SENTINEL_MSG, block=False) n_sentinels_sent += 1 except Full: break @@ -352,19 +362,22 @@ def shutdown_worker(): ready = wait(readers + worker_sentinels) cause = None - is_broken = True + thread_wakeup.clear() if result_reader in ready: try: - result_item = result_reader.recv() - is_broken = False + serialize_res = result_reader.recv_bytes() + work_id = int.from_bytes(serialize_res[:WORK_ID_SIZE], + WORK_ID_ENC) + result_item = work_id + if len(serialize_res) > WORK_ID_SIZE: + result_item = _ForkingPickler.loads( + serialize_res[WORK_ID_SIZE:]) except BaseException as e: - cause = traceback.format_exception(type(e), e, e.__traceback__) - + result_item = _ResultItem(work_id, exception=e) elif wakeup_reader in ready: is_broken = False result_item = None - thread_wakeup.clear() - if is_broken: + else: # Mark the process pool broken so that submits fail right now. executor = executor_reference() if executor is not None: @@ -531,9 +544,7 @@ def __init__(self, max_workers=None, mp_context=None, # prevent the worker processes from idling. But don't make it too big # because futures in the call queue cannot be cancelled. queue_size = self._max_workers + EXTRA_QUEUED_CALLS - self._call_queue = _SafeQueue( - max_size=queue_size, ctx=self._mp_context, - pending_work_items=self._pending_work_items) + self._call_queue = Queue(queue_size, ctx=self._mp_context) # Killed worker processes can produce spurious "broken pipe" # tracebacks in the queue's own worker thread. But we detect killed # processes anyway, so silence the tracebacks. diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py index d66d37a5c3e2eb..52727e41857999 100644 --- a/Lib/multiprocessing/queues.py +++ b/Lib/multiprocessing/queues.py @@ -27,6 +27,22 @@ from .util import debug, info, Finalize, register_after_fork, is_exiting +# +# Sendable Object, with a serialization protocol +# + + +class _SendableObject(object): + def __init__(self, obj, serialization=None): + self.obj = obj + self.serialization = serialization + + def serialize(self): + if self.serialization: + return self.serialization(self.obj) + return self.obj + + # # Queue type using a pipe, buffer and thread # @@ -78,6 +94,10 @@ def _after_fork(self): self._poll = self._reader.poll def put(self, obj, block=True, timeout=None): + self._put_bytes(obj, block=block, timeout=timeout, + serialization=_ForkingPickler.dumps) + + def _put_bytes(self, obj, block=True, timeout=None, serialization=None): assert not self._closed, "Queue {0!r} has been closed".format(self) if not self._sem.acquire(block, timeout): raise Full @@ -85,10 +105,15 @@ def put(self, obj, block=True, timeout=None): with self._notempty: if self._thread is None: self._start_thread() - self._buffer.append(obj) + self._buffer.append(_SendableObject( + obj, serialization=serialization)) self._notempty.notify() def get(self, block=True, timeout=None): + return self._get_bytes(block=block, timeout=timeout, + deserialization=_ForkingPickler.loads) + + def _get_bytes(self, block=True, timeout=None, deserialization=None): if block and timeout is None: with self._rlock: res = self._recv_bytes() @@ -109,8 +134,10 @@ def get(self, block=True, timeout=None): self._sem.release() finally: self._rlock.release() - # unserialize the data after having released the lock - return _ForkingPickler.loads(res) + # un-serialize the data after having released the lock + if deserialization: + return deserialization(res) + return res def qsize(self): # Raises NotImplementedError on Mac OSX because of broken sem_getvalue() @@ -233,7 +260,7 @@ def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe, return # serialize the data before acquiring the lock - obj = _ForkingPickler.dumps(obj) + obj = obj.serialize() if wacquire is None: send_bytes(obj) else: @@ -255,7 +282,7 @@ def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe, info('error in queue thread: %s', e) return else: - onerror(e, obj) + onerror(e, obj.obj) @staticmethod def _on_queue_feeder_error(e, obj): @@ -299,7 +326,8 @@ def put(self, obj, block=True, timeout=None): with self._notempty, self._cond: if self._thread is None: self._start_thread() - self._buffer.append(obj) + self._buffer.append(_SendableObject( + obj, serialization=_ForkingPickler.dumps)) self._unfinished_tasks.release() self._notempty.notify() @@ -342,14 +370,25 @@ def __setstate__(self, state): self._poll = self._reader.poll def get(self): + # Get the object and deserialize it with the _ForkingPickler + return self._get_bytes(deserialization=_ForkingPickler.loads) + + def _get_bytes(self, deserialization=None): with self._rlock: res = self._reader.recv_bytes() # unserialize the data after having released the lock - return _ForkingPickler.loads(res) + if deserialization: + return deserialization(res) + return res def put(self, obj): + # Get the object and deserialize it with the _ForkingPickler + self._put_bytes(obj, serialization=_ForkingPickler.dumps) + + def _put_bytes(self, obj, serialization=None): # serialize the data before acquiring the lock - obj = _ForkingPickler.dumps(obj) + if serialization: + obj = serialization(obj) if self._wlock is None: # writes to a message oriented win32 pipe are atomic self._writer.send_bytes(obj) diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 05166b91ba832a..469381a6be740b 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -18,6 +18,7 @@ import random import logging import struct +import pickle import operator import weakref import test.support @@ -1066,6 +1067,24 @@ def _on_queue_feeder_error(e, obj): # Assert that the serialization and the hook have been called correctly self.assertTrue(not_serializable_obj.reduce_was_called) self.assertTrue(not_serializable_obj.on_queue_feeder_error_was_called) + + def test_queue_serialization(self): + # bpo-30006: verify feeder handles exceptions using the + # _on_queue_feeder_error hook. + if self.TYPE != 'processes': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + q = self.Queue() + + # Custom serialization + def serialization(x): + return pickle.dumps(x * 2) + q._put_bytes(21, serialization=serialization) + self.assertEqual(q.get(), 42) + + # Custom bytes channels + q._put_bytes(bytes(42), serialization=None) + self.assertEqual(q._get_bytes(), bytes(42)) # # # diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py index 675cd7ae05e5fc..443a28f4f61e23 100644 --- a/Lib/test/test_concurrent_futures.py +++ b/Lib/test/test_concurrent_futures.py @@ -18,7 +18,7 @@ import time import unittest import weakref -from pickle import PicklingError +from pickle import PicklingError, UnpicklingError from concurrent import futures from concurrent.futures._base import ( @@ -888,11 +888,12 @@ def test_crash(self): crash_cases = [ # Check problem occuring while pickling a task in # the task_handler thread - (id, (ErrorAtPickle(),), PicklingError, "error at task pickle"), + (id, (ErrorAtPickle(),), PicklingError, + "error at task pickle"), # Check problem occuring while unpickling a task on workers - (id, (ExitAtUnpickle(),), BrokenProcessPool, + (id, (ExitAtUnpickle(),), SystemExit, "exit at task unpickle"), - (id, (ErrorAtUnpickle(),), BrokenProcessPool, + (id, (ErrorAtUnpickle(),), UnpicklingError, "error at task unpickle"), (id, (CrashAtUnpickle(),), BrokenProcessPool, "crash at task unpickle"), @@ -913,9 +914,9 @@ def test_crash(self): "error during result pickle on worker"), # Check problem occuring while unpickling a task in # the result_handler thread - (_return_instance, (ErrorAtUnpickle,), BrokenProcessPool, + (_return_instance, (ErrorAtUnpickle,), UnpicklingError, "error during result unpickle in result_handler"), - (_return_instance, (ExitAtUnpickle,), BrokenProcessPool, + (_return_instance, (ExitAtUnpickle,), SystemExit, "exit during result unpickle in result_handler") ] for func, args, error, name in crash_cases: diff --git a/Misc/NEWS.d/next/Library/2018-01-12-15-27-06.bpo-31699.l8g_ld.rst b/Misc/NEWS.d/next/Library/2018-01-12-15-27-06.bpo-31699.l8g_ld.rst new file mode 100644 index 00000000000000..3c1c293cc1e170 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-01-12-15-27-06.bpo-31699.l8g_ld.rst @@ -0,0 +1,2 @@ +Fix :class:`concurrent.futures.ProcessPoolExecutor` so unpickling errors do +not break the executor and are send back as errors on the faulty task.