@@ -26,9 +26,12 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 from functools import partial
+from inspect import currentframe
 from tempfile import NamedTemporaryFile
+from types import FrameType
 from typing import IO, TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union
 from urllib.parse import quote
+from weakref import WeakKeyDictionary
 
 import dill
 import jinja2
@@ -115,6 +118,8 @@
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG, DagModel, DagRun
 
+_EXECUTION_FRAME_MAPPING: "WeakKeyDictionary[Operator, FrameType]" = WeakKeyDictionary()
+
 
 @contextlib.contextmanager
 def set_current_context(context: Context) -> None:
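The new module-level mapping records, for each operator copy that raised, the frame of the framework call that invoked it, so failure handling can later trim the logged traceback. Keeping it in a WeakKeyDictionary means the entry (and the captured frame) is released as soon as the operator copy is garbage collected. A minimal sketch of that behaviour, using a stand-in Task class rather than an Airflow operator:

import gc
from weakref import WeakKeyDictionary


class Task:
    """Stand-in key type; instances are hashable and weak-referenceable."""


_frames: "WeakKeyDictionary[Task, object]" = WeakKeyDictionary()

task = Task()
_frames[task] = object()   # e.g. a captured execution frame
assert len(_frames) == 1

del task                   # drop the last strong reference to the key
gc.collect()
assert len(_frames) == 0   # the entry disappeared with it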
@@ -1409,7 +1414,18 @@ def _execute_task_with_callbacks(self, context):
         """Prepare Task for Execution"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
+        parent_pid = os.getpid()
+
         def signal_handler(signum, frame):
+            pid = os.getpid()
+
+            # If a task forks during execution (from DAG code) for whatever
+            # reason, we want to make sure that we react to the signal only in
+            # the process that we've spawned ourselves (referred to here as the
+            # parent process).
+            if pid != parent_pid:
+                os._exit(1)
+                return
             self.log.error("Received SIGTERM. Terminating subprocesses.")
             self.task.on_kill()
             raise AirflowException("Task received SIGTERM signal")
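The PID guard is needed because a handler registered with signal.signal() is inherited across os.fork(), so DAG code that forks would otherwise run the parent's on_kill()/raise logic in every child that receives SIGTERM. A small POSIX-only sketch of the same guard (handler body and messages are illustrative, not Airflow's):

import os
import signal
import sys

parent_pid = os.getpid()


def signal_handler(signum, frame):
    # Forked children inherit this handler; only the original process
    # should run the real shutdown logic.
    if os.getpid() != parent_pid:
        os._exit(1)
    print("parent: shutting down cleanly")
    sys.exit(0)


signal.signal(signal.SIGTERM, signal_handler)

if os.fork() == 0:
    # Child: the inherited handler takes the fast os._exit(1) path and
    # never touches the parent's cleanup.
    os.kill(os.getpid(), signal.SIGTERM)
else:
    os.wait()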
@@ -1480,7 +1496,6 @@ def _execute_task(self, context, task_copy):
         # If the task has been deferred and is being executed due to a trigger,
         # then we need to pick the right method to come back to, otherwise
         # we go for the default execute
-        execute_callable = task_copy.execute
         if self.next_method:
             # __fail__ is a special signal value for next_method that indicates
             # this task was scheduled specifically to fail.
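With the default moved into an else: branch (next hunk), the resume path reads: look up next_method on the operator and, if the trigger supplied next_kwargs, pre-bind them with functools.partial so only context is passed at call time. An illustrative sketch with a made-up operator class (not an Airflow API):

from functools import partial


class DemoOperator:
    """Hypothetical operator standing in for task_copy."""

    def execute(self, context):
        return "fresh run"

    def execute_complete(self, context, event=None):
        return f"resumed with {event}"


op = DemoOperator()
next_method, next_kwargs = "execute_complete", {"event": {"status": "success"}}

execute_callable = getattr(op, next_method)
if next_kwargs:
    execute_callable = partial(execute_callable, **next_kwargs)

# Only context is supplied here; event was bound above.
assert execute_callable(context={}) == "resumed with {'status': 'success'}"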
@@ -1494,29 +1509,35 @@ def _execute_task(self, context, task_copy):
             execute_callable = getattr(task_copy, self.next_method)
             if self.next_kwargs:
                 execute_callable = partial(execute_callable, **self.next_kwargs)
+        else:
+            execute_callable = task_copy.execute
         # If a timeout is specified for the task, make it fail
         # if it goes beyond
-        if task_copy.execution_timeout:
-            # If we are coming in with a next_method (i.e. from a deferral),
-            # calculate the timeout from our start_date.
-            if self.next_method:
-                timeout_seconds = (
-                    task_copy.execution_timeout - (timezone.utcnow() - self.start_date)
-                ).total_seconds()
+        try:
+            if task_copy.execution_timeout:
+                # If we are coming in with a next_method (i.e. from a deferral),
+                # calculate the timeout from our start_date.
+                if self.next_method:
+                    timeout_seconds = (
+                        task_copy.execution_timeout - (timezone.utcnow() - self.start_date)
+                    ).total_seconds()
+                else:
+                    timeout_seconds = task_copy.execution_timeout.total_seconds()
+                try:
+                    # It's possible we're already timed out, so fast-fail if true
+                    if timeout_seconds <= 0:
+                        raise AirflowTaskTimeout()
+                    # Run task in timeout wrapper
+                    with timeout(timeout_seconds):
+                        result = execute_callable(context=context)
+                except AirflowTaskTimeout:
+                    task_copy.on_kill()
+                    raise
             else:
-                timeout_seconds = task_copy.execution_timeout.total_seconds()
-            try:
-                # It's possible we're already timed out, so fast-fail if true
-                if timeout_seconds <= 0:
-                    raise AirflowTaskTimeout()
-                # Run task in timeout wrapper
-                with timeout(timeout_seconds):
-                    result = execute_callable(context=context)
-            except AirflowTaskTimeout:
-                task_copy.on_kill()
-                raise
-        else:
-            result = execute_callable(context=context)
+                result = execute_callable(context=context)
+        except:  # noqa: E722
+            _EXECUTION_FRAME_MAPPING[task_copy] = currentframe()
+            raise
         # If the task returns a result, push an XCom containing it
         if task_copy.do_xcom_push and result is not None:
             self.xcom_push(key=XCOM_RETURN_KEY, value=result)
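On resume from a deferral the timeout budget is whatever remains of execution_timeout measured from the original start_date, and a non-positive remainder fast-fails with AirflowTaskTimeout before user code runs. A standalone sketch of that arithmetic (the timestamps are made up):

from datetime import datetime, timedelta, timezone

execution_timeout = timedelta(minutes=10)
start_date = datetime(2021, 9, 1, 12, 0, tzinfo=timezone.utc)

# Pretend the trigger fired 7 minutes after the task originally started.
now = start_date + timedelta(minutes=7)
timeout_seconds = (execution_timeout - (now - start_date)).total_seconds()
assert timeout_seconds == 180.0   # three minutes of budget left

# If the deferral outlived the budget, the remainder is <= 0 and the
# task is failed immediately instead of being run with no time left.
late = start_date + timedelta(minutes=12)
assert (execution_timeout - (late - start_date)).total_seconds() <= 0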
@@ -1718,7 +1739,15 @@ def handle_failure(
 
         if error:
             if isinstance(error, Exception):
-                self.log.error("Task failed with exception", exc_info=error)
+                execution_frame = _EXECUTION_FRAME_MAPPING.get(self.task)
+                tb = error.__traceback__
+                while tb is not None:
+                    if tb.tb_frame is execution_frame:
+                        tb = tb.tb_next
+                        break
+                    tb = tb.tb_next
+                tb = tb or error.__traceback__
+                self.log.error("Task failed with exception", exc_info=(type(error), error, tb))
             else:
                 self.log.error("%s", error)
         # external monitoring process provides pickle file so _run_raw_task
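handle_failure walks error.__traceback__ until it reaches the frame recorded in _EXECUTION_FRAME_MAPPING and then takes tb_next, so the logged traceback starts in the task's own code rather than in the framework frames above it; if the frame is not found it falls back to the full traceback. A self-contained sketch of the same trimming, with placeholder function names:

import logging
from inspect import currentframe

logging.basicConfig(level=logging.ERROR)
log = logging.getLogger("demo")


def user_task():
    raise ValueError("boom")   # stands in for failing DAG code


def framework_wrapper():
    boundary = currentframe()  # the frame we want to cut away
    try:
        user_task()
    except Exception as error:
        tb = error.__traceback__
        while tb is not None:
            if tb.tb_frame is boundary:
                tb = tb.tb_next  # first frame below the framework call
                break
            tb = tb.tb_next
        tb = tb or error.__traceback__  # frame not found: keep everything
        log.error("Task failed with exception", exc_info=(type(error), error, tb))


framework_wrapper()
# The logged traceback now begins at user_task(), not framework_wrapper().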