@@ -1804,14 +1804,14 @@ async def test_story_with_deps(c, s, a, b):
1804
1804
stimulus_ids .add (msg [- 2 ])
1805
1805
pruned_story .append (tuple (pruned_msg [:- 2 ]))
1806
1806
1807
- assert len (stimulus_ids ) == 3
1807
+ assert len (stimulus_ids ) == 3 , stimulus_ids
1808
1808
stimulus_id = pruned_story [0 ][- 1 ]
1809
1809
assert isinstance (stimulus_id , str )
1810
1810
assert stimulus_id .startswith ("compute-task" )
1811
1811
# This is a simple transition log
1812
1812
expected_story = [
1813
1813
(key , "compute-task" ),
1814
- (key , "released" , "waiting" , {}),
1814
+ (key , "released" , "waiting" , {dep . key : "fetch" }),
1815
1815
(key , "waiting" , "ready" , {}),
1816
1816
(key , "ready" , "executing" , {}),
1817
1817
(key , "put-in-memory" ),
@@ -1832,11 +1832,11 @@ async def test_story_with_deps(c, s, a, b):
1832
1832
stimulus_ids .add (msg [- 2 ])
1833
1833
pruned_story .append (tuple (pruned_msg [:- 2 ]))
1834
1834
1835
- assert len (stimulus_ids ) == 3
1835
+ assert len (stimulus_ids ) == 2 , stimulus_ids
1836
1836
stimulus_id = pruned_story [0 ][- 1 ]
1837
1837
assert isinstance (stimulus_id , str )
1838
1838
expected_story = [
1839
- (dep_story , "register-replica " , "released" ),
1839
+ (dep_story , "ensure-task-exists " , "released" ),
1840
1840
(dep_story , "released" , "fetch" , {}),
1841
1841
(
1842
1842
"gather-dependencies" ,
@@ -2794,7 +2794,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b):
2794
2794
_acquire_replicas (s , b , fut )
2795
2795
2796
2796
await futC
2797
- while fut .key not in b .tasks :
2797
+ while fut .key not in b .tasks or not b . tasks [ fut . key ]. state == "memory" :
2798
2798
await asyncio .sleep (0.005 )
2799
2799
assert len (s .who_has [fut .key ]) == 2
2800
2800
@@ -3082,12 +3082,14 @@ def clear_leak():
3082
3082
]
3083
3083
3084
3084
3085
- async def _wait_for_flight (key , worker ):
3086
- while key not in worker .tasks or worker .tasks [key ].state != "flight" :
3085
+ async def _wait_for_state (key : str , worker : Worker , state : str ):
3086
+ # Keep the sleep interval at 0 since the tests using this are very sensitive
3087
+ # about timing. they intend to capture loop cycles after this specific
3088
+ # condition was set
3089
+ while key not in worker .tasks or worker .tasks [key ].state != state :
3087
3090
await asyncio .sleep (0 )
3088
3091
3089
3092
3090
- @pytest .mark .xfail (reason = "#5406" )
3091
3093
@gen_cluster (client = True )
3092
3094
async def test_gather_dep_do_not_handle_response_of_not_requested_tasks (c , s , a , b ):
3093
3095
"""At time of writing, the gather_dep implementation filtered tasks again
@@ -3107,21 +3109,26 @@ async def test_gather_dep_do_not_handle_response_of_not_requested_tasks(c, s, a,
3107
3109
3108
3110
fut2_key = fut2 .key
3109
3111
3110
- await _wait_for_flight (fut2_key , b )
3112
+ await _wait_for_state (fut2_key , b , "flight" )
3113
+ while not mocked_gather .call_args :
3114
+ await asyncio .sleep (0 )
3111
3115
3112
3116
fut4 .release ()
3113
3117
while fut4 .key in b .tasks :
3114
3118
await asyncio .sleep (0 )
3115
3119
3116
- story_before = b .story (fut2 .key )
3117
- assert fut2 .key in mocked_gather .call_args .kwargs ["to_gather" ]
3118
- await Worker .gather_dep (b , ** mocked_gather .call_args .kwargs )
3119
- story_after = b .story (fut2 .key )
3120
- assert story_before == story_after
3120
+ assert b .tasks [fut2 .key ].state == "cancelled"
3121
+ args , kwargs = mocked_gather .call_args
3122
+ assert fut2 .key in kwargs ["to_gather" ]
3123
+
3124
+ await Worker .gather_dep (b , * args , ** kwargs )
3125
+ assert fut2 .key not in b .tasks
3126
+ f2_story = b .story (fut2 .key )
3127
+ assert f2_story
3128
+ assert not any ("missing-dep" in msg for msg in b .story (fut2 .key ))
3121
3129
await fut3
3122
3130
3123
3131
3124
- @pytest .mark .xfail (reason = "#5406" )
3125
3132
@gen_cluster (
3126
3133
client = True ,
3127
3134
config = {
@@ -3137,13 +3144,55 @@ async def test_gather_dep_no_longer_in_flight_tasks(c, s, a, b):
3137
3144
3138
3145
fut1_key = fut1 .key
3139
3146
3140
- await _wait_for_flight (fut1_key , b )
3147
+ await _wait_for_state (fut1_key , b , "flight" )
3148
+ while not mocked_gather .call_args :
3149
+ await asyncio .sleep (0 )
3141
3150
3142
3151
fut2 .release ()
3143
3152
while fut2 .key in b .tasks :
3144
3153
await asyncio .sleep (0 )
3145
3154
3146
- assert b .tasks [fut1 .key ] != "flight"
3147
- log_before = list (b .log )
3148
- await Worker .gather_dep (b , ** mocked_gather .call_args .kwargs )
3149
- assert log_before == list (b .log )
3155
+ assert b .tasks [fut1 .key ].state == "cancelled"
3156
+
3157
+ args , kwargs = mocked_gather .call_args
3158
+ await Worker .gather_dep (b , * args , ** kwargs )
3159
+
3160
+ assert fut2 .key not in b .tasks
3161
+ f1_story = b .story (fut1 .key )
3162
+ assert f1_story
3163
+ assert not any ("missing-dep" in msg for msg in b .story (fut2 .key ))
3164
+
3165
+
3166
+ @pytest .mark .parametrize ("intermediate_state" , ["resumed" , "cancelled" ])
3167
+ @pytest .mark .parametrize ("close_worker" , [False , True ])
3168
+ @gen_cluster (client = True , nthreads = [("" , 1 )] * 3 )
3169
+ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker (
3170
+ c , s , a , b , x , intermediate_state , close_worker
3171
+ ):
3172
+ """If a task was transitioned to in-flight, the gather-dep coroutine was
3173
+ scheduled but a cancel request came in before gather_data_from_worker was
3174
+ issued this might corrupt the state machine if the cancelled key is not
3175
+ properly handled"""
3176
+
3177
+ fut1 = c .submit (slowinc , 1 , workers = [a .address ], key = "f1" )
3178
+ fut1B = c .submit (slowinc , 2 , workers = [x .address ], key = "f1B" )
3179
+ fut2 = c .submit (sum , [fut1 , fut1B ], workers = [x .address ], key = "f2" )
3180
+ await fut2
3181
+ with mock .patch .object (distributed .worker .Worker , "gather_dep" ) as mocked_gather :
3182
+ fut3 = c .submit (inc , fut2 , workers = [b .address ], key = "f3" )
3183
+
3184
+ fut2_key = fut2 .key
3185
+
3186
+ await _wait_for_state (fut2_key , b , "flight" )
3187
+
3188
+ s .set_restrictions (worker = {fut1B .key : a .address , fut2 .key : b .address })
3189
+ while not mocked_gather .call_args :
3190
+ await asyncio .sleep (0 )
3191
+
3192
+ await s .remove_worker (address = x .address , safe = True , close = close_worker )
3193
+
3194
+ await _wait_for_state (fut2_key , b , intermediate_state )
3195
+
3196
+ args , kwargs = mocked_gather .call_args
3197
+ await Worker .gather_dep (b , * args , ** kwargs )
3198
+ await fut3
0 commit comments