Skip to content

Commit 1670cf8

Browse files
authored
Ensure resumed flight tasks are still fetched (#5426)
1 parent cdc68cc commit 1670cf8

File tree

4 files changed

+224
-124
lines changed

4 files changed

+224
-124
lines changed

distributed/tests/test_client.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@
102102

103103
@gen_cluster(client=True)
104104
async def test_submit(c, s, a, b):
105-
x = c.submit(inc, 10)
105+
x = c.submit(inc, 10, key="x")
106106
assert not x.done()
107107

108108
assert isinstance(x, Future)
@@ -112,7 +112,7 @@ async def test_submit(c, s, a, b):
112112
assert result == 11
113113
assert x.done()
114114

115-
y = c.submit(inc, 20)
115+
y = c.submit(inc, 20, key="y")
116116
z = c.submit(add, x, y)
117117

118118
result = await z

distributed/tests/test_steal.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -987,6 +987,8 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
987987
slowinc,
988988
range(10),
989989
key=[f"f1-{ix}" for ix in range(10)],
990+
workers=[w0.address],
991+
allow_other_workers=True,
990992
)
991993
while not w0.active_keys:
992994
await asyncio.sleep(0.01)

distributed/tests/test_worker.py

Lines changed: 69 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1804,14 +1804,14 @@ async def test_story_with_deps(c, s, a, b):
18041804
stimulus_ids.add(msg[-2])
18051805
pruned_story.append(tuple(pruned_msg[:-2]))
18061806

1807-
assert len(stimulus_ids) == 3
1807+
assert len(stimulus_ids) == 3, stimulus_ids
18081808
stimulus_id = pruned_story[0][-1]
18091809
assert isinstance(stimulus_id, str)
18101810
assert stimulus_id.startswith("compute-task")
18111811
# This is a simple transition log
18121812
expected_story = [
18131813
(key, "compute-task"),
1814-
(key, "released", "waiting", {}),
1814+
(key, "released", "waiting", {dep.key: "fetch"}),
18151815
(key, "waiting", "ready", {}),
18161816
(key, "ready", "executing", {}),
18171817
(key, "put-in-memory"),
@@ -1832,11 +1832,11 @@ async def test_story_with_deps(c, s, a, b):
18321832
stimulus_ids.add(msg[-2])
18331833
pruned_story.append(tuple(pruned_msg[:-2]))
18341834

1835-
assert len(stimulus_ids) == 3
1835+
assert len(stimulus_ids) == 2, stimulus_ids
18361836
stimulus_id = pruned_story[0][-1]
18371837
assert isinstance(stimulus_id, str)
18381838
expected_story = [
1839-
(dep_story, "register-replica", "released"),
1839+
(dep_story, "ensure-task-exists", "released"),
18401840
(dep_story, "released", "fetch", {}),
18411841
(
18421842
"gather-dependencies",
@@ -2794,7 +2794,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
27942794
_acquire_replicas(s, b, fut)
27952795

27962796
await futC
2797-
while fut.key not in b.tasks:
2797+
while fut.key not in b.tasks or not b.tasks[fut.key].state == "memory":
27982798
await asyncio.sleep(0.005)
27992799
assert len(s.who_has[fut.key]) == 2
28002800

@@ -3082,12 +3082,14 @@ def clear_leak():
30823082
]
30833083

30843084

3085-
async def _wait_for_flight(key, worker):
3086-
while key not in worker.tasks or worker.tasks[key].state != "flight":
3085+
async def _wait_for_state(key: str, worker: Worker, state: str):
    """Yield to the event loop until ``worker.tasks[key].state == state``.

    Parameters
    ----------
    key
        Key to look up in ``worker.tasks``.
    worker
        Worker whose ``tasks`` mapping is polled.
    state
        State string the task has to reach before this coroutine returns.
    """

    def _reached() -> bool:
        # The key may not be registered on the worker yet, so test
        # membership before touching the task's state.
        return key in worker.tasks and worker.tasks[key].state == state

    # The sleep interval stays at 0 on purpose: the tests built on this
    # helper are very timing sensitive. They intend to capture the loop
    # cycles right after the condition became true, so any real delay
    # could step over the window they want to observe.
    while not _reached():
        await asyncio.sleep(0)
30883091

30893092

3090-
@pytest.mark.xfail(reason="#5406")
30913093
@gen_cluster(client=True)
30923094
async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a, b):
30933095
"""At time of writing, the gather_dep implementation filtered tasks again
@@ -3107,21 +3109,26 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a,
31073109

31083110
fut2_key = fut2.key
31093111

3110-
await _wait_for_flight(fut2_key, b)
3112+
await _wait_for_state(fut2_key, b, "flight")
3113+
while not mocked_gather.call_args:
3114+
await asyncio.sleep(0)
31113115

31123116
fut4.release()
31133117
while fut4.key in b.tasks:
31143118
await asyncio.sleep(0)
31153119

3116-
story_before = b.story(fut2.key)
3117-
assert fut2.key in mocked_gather.call_args.kwargs["to_gather"]
3118-
await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
3119-
story_after = b.story(fut2.key)
3120-
assert story_before == story_after
3120+
assert b.tasks[fut2.key].state == "cancelled"
3121+
args, kwargs = mocked_gather.call_args
3122+
assert fut2.key in kwargs["to_gather"]
3123+
3124+
await Worker.gather_dep(b, *args, **kwargs)
3125+
assert fut2.key not in b.tasks
3126+
f2_story = b.story(fut2.key)
3127+
assert f2_story
3128+
assert not any("missing-dep" in msg for msg in b.story(fut2.key))
31213129
await fut3
31223130

31233131

3124-
@pytest.mark.xfail(reason="#5406")
31253132
@gen_cluster(
31263133
client=True,
31273134
config={
@@ -3137,13 +3144,55 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a, b):
31373144

31383145
fut1_key = fut1.key
31393146

3140-
await _wait_for_flight(fut1_key, b)
3147+
await _wait_for_state(fut1_key, b, "flight")
3148+
while not mocked_gather.call_args:
3149+
await asyncio.sleep(0)
31413150

31423151
fut2.release()
31433152
while fut2.key in b.tasks:
31443153
await asyncio.sleep(0)
31453154

3146-
assert b.tasks[fut1.key] != "flight"
3147-
log_before = list(b.log)
3148-
await Worker.gather_dep(b, **mocked_gather.call_args.kwargs)
3149-
assert log_before == list(b.log)
3155+
assert b.tasks[fut1.key].state == "cancelled"
3156+
3157+
args, kwargs = mocked_gather.call_args
3158+
await Worker.gather_dep(b, *args, **kwargs)
3159+
3160+
assert fut2.key not in b.tasks
3161+
f1_story = b.story(fut1.key)
3162+
assert f1_story
3163+
assert not any("missing-dep" in msg for msg in b.story(fut2.key))
3164+
3165+
3166+
@pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"])
@pytest.mark.parametrize("close_worker", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)] * 3)
async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
    c, s, a, b, x, intermediate_state, close_worker
):
    """Cancel an in-flight key before ``gather_dep`` actually runs.

    Scenario: a task has been transitioned to in-flight and the gather-dep
    coroutine has been scheduled, but a cancel request arrives before
    ``gather_data_from_worker`` is issued. If the cancelled key is not
    handled properly, this might corrupt the worker state machine.

    The test records the arguments of the scheduled ``gather_dep`` call with
    a mock, removes the worker holding the data, waits until the key reaches
    ``intermediate_state`` on ``b``, and only then replays the recorded call
    against the real implementation. The final ``await`` on ``f3`` has to
    complete for the test to pass.
    """
    dep_on_a = c.submit(slowinc, 1, workers=[a.address], key="f1")
    dep_on_x = c.submit(slowinc, 2, workers=[x.address], key="f1B")
    total = c.submit(sum, [dep_on_a, dep_on_x], workers=[x.address], key="f2")
    await total
    total_key = total.key

    # While the patch is active, gather_dep calls are only recorded and
    # never executed, so "f2" stays in flight on ``b``.
    with mock.patch.object(distributed.worker.Worker, "gather_dep") as mocked_gather:
        downstream = c.submit(inc, total, workers=[b.address], key="f3")

        await _wait_for_state(total_key, b, "flight")

        # NOTE(review): presumably re-pins the keys living on ``x`` to the
        # surviving workers so they can be recomputed once ``x`` is removed
        # -- confirm against the scheduler.
        s.set_restrictions(
            worker={dep_on_x.key: a.address, total.key: b.address}
        )

        # Do not go on before the worker has actually scheduled gather_dep.
        while not mocked_gather.call_args:
            await asyncio.sleep(0)

        await s.remove_worker(address=x.address, safe=True, close=close_worker)

        await _wait_for_state(total_key, b, intermediate_state)

    # The patch is lifted here: replay the recorded call against the real
    # gather_dep, now that the key is no longer in flight.
    recorded_args, recorded_kwargs = mocked_gather.call_args
    await Worker.gather_dep(b, *recorded_args, **recorded_kwargs)
    await downstream

0 commit comments

Comments
 (0)