
Commit fd3b4e9

malthe authored and ephraimbuddy committed
Truncate stack trace to DAG user code for exceptions raised during execution (#20731)
(cherry picked from commit 7ea0f76)
1 parent: 372ae59 · commit: fd3b4e9

File tree: 4 files changed, +106 -32 lines


airflow/models/taskinstance.py

Lines changed: 51 additions & 22 deletions
@@ -26,9 +26,12 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 from functools import partial
+from inspect import currentframe
 from tempfile import NamedTemporaryFile
+from types import FrameType
 from typing import IO, TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union
 from urllib.parse import quote
+from weakref import WeakKeyDictionary

 import dill
 import jinja2
@@ -115,6 +118,8 @@
     from airflow.models.baseoperator import BaseOperator
     from airflow.models.dag import DAG, DagModel, DagRun

+_EXECUTION_FRAME_MAPPING: "WeakKeyDictionary[Operator, FrameType]" = WeakKeyDictionary()
+

 @contextlib.contextmanager
 def set_current_context(context: Context) -> None:
@@ -1409,7 +1414,18 @@ def _execute_task_with_callbacks(self, context):
         """Prepare Task for Execution"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields

+        parent_pid = os.getpid()
+
         def signal_handler(signum, frame):
+            pid = os.getpid()
+
+            # If a task forks during execution (from DAG code) for whatever
+            # reason, we want to make sure that we react to the signal only in
+            # the process that we've spawned ourselves (referred to here as the
+            # parent process).
+            if pid != parent_pid:
+                os._exit(1)
+                return
             self.log.error("Received SIGTERM. Terminating subprocesses.")
             self.task.on_kill()
             raise AirflowException("Task received SIGTERM signal")
@@ -1480,7 +1496,6 @@ def _execute_task(self, context, task_copy):
         # If the task has been deferred and is being executed due to a trigger,
         # then we need to pick the right method to come back to, otherwise
         # we go for the default execute
-        execute_callable = task_copy.execute
         if self.next_method:
             # __fail__ is a special signal value for next_method that indicates
             # this task was scheduled specifically to fail.
@@ -1494,29 +1509,35 @@ def _execute_task(self, context, task_copy):
             execute_callable = getattr(task_copy, self.next_method)
             if self.next_kwargs:
                 execute_callable = partial(execute_callable, **self.next_kwargs)
+        else:
+            execute_callable = task_copy.execute
         # If a timeout is specified for the task, make it fail
         # if it goes beyond
-        if task_copy.execution_timeout:
-            # If we are coming in with a next_method (i.e. from a deferral),
-            # calculate the timeout from our start_date.
-            if self.next_method:
-                timeout_seconds = (
-                    task_copy.execution_timeout - (timezone.utcnow() - self.start_date)
-                ).total_seconds()
+        try:
+            if task_copy.execution_timeout:
+                # If we are coming in with a next_method (i.e. from a deferral),
+                # calculate the timeout from our start_date.
+                if self.next_method:
+                    timeout_seconds = (
+                        task_copy.execution_timeout - (timezone.utcnow() - self.start_date)
+                    ).total_seconds()
+                else:
+                    timeout_seconds = task_copy.execution_timeout.total_seconds()
+                try:
+                    # It's possible we're already timed out, so fast-fail if true
+                    if timeout_seconds <= 0:
+                        raise AirflowTaskTimeout()
+                    # Run task in timeout wrapper
+                    with timeout(timeout_seconds):
+                        result = execute_callable(context=context)
+                except AirflowTaskTimeout:
+                    task_copy.on_kill()
+                    raise
             else:
-                timeout_seconds = task_copy.execution_timeout.total_seconds()
-            try:
-                # It's possible we're already timed out, so fast-fail if true
-                if timeout_seconds <= 0:
-                    raise AirflowTaskTimeout()
-                # Run task in timeout wrapper
-                with timeout(timeout_seconds):
-                    result = execute_callable(context=context)
-            except AirflowTaskTimeout:
-                task_copy.on_kill()
-                raise
-        else:
-            result = execute_callable(context=context)
+                result = execute_callable(context=context)
+        except:  # noqa: E722
+            _EXECUTION_FRAME_MAPPING[task_copy] = currentframe()
+            raise
         # If the task returns a result, push an XCom containing it
         if task_copy.do_xcom_push and result is not None:
             self.xcom_push(key=XCOM_RETURN_KEY, value=result)
@@ -1718,7 +1739,15 @@ def handle_failure(

         if error:
             if isinstance(error, Exception):
-                self.log.error("Task failed with exception", exc_info=error)
+                execution_frame = _EXECUTION_FRAME_MAPPING.get(self.task)
+                tb = error.__traceback__
+                while tb is not None:
+                    if tb.tb_frame is execution_frame:
+                        tb = tb.tb_next
+                        break
+                    tb = tb.tb_next
+                tb = tb or error.__traceback__
+                self.log.error("Task failed with exception", exc_info=(type(error), error, tb))
             else:
                 self.log.error("%s", error)
         # external monitoring process provides pickle file so _run_raw_task
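The core of this file's change is the pair of additions to _execute_task and handle_failure: when an exception escapes user code, the frame of the framework method that invoked it is recorded in a WeakKeyDictionary, and the failure handler later walks error.__traceback__ past that frame so the logged stack trace starts in DAG code rather than in Airflow internals. The following is a minimal, standalone sketch of that pattern; execute, handle_failure, user_task and _boundary_frames are hypothetical names, not the Airflow implementation:

import logging
from inspect import currentframe

logging.basicConfig(level=logging.ERROR)
log = logging.getLogger("sketch")

# task -> frame captured at the framework/user-code boundary
_boundary_frames = {}


def execute(task, fn):
    """Stand-in for the framework method that calls into user code."""
    try:
        return fn()
    except Exception:
        # Remember this function's frame; traceback entries below it
        # belong to user code.
        _boundary_frames[task] = currentframe()
        raise


def handle_failure(task, error):
    """Log the error with a traceback that starts below the boundary frame."""
    boundary = _boundary_frames.get(task)
    tb = error.__traceback__
    while tb is not None:
        if tb.tb_frame is boundary:
            tb = tb.tb_next  # first entry below the framework call
            break
        tb = tb.tb_next
    tb = tb or error.__traceback__  # fall back to the full traceback
    log.error("Task failed with exception", exc_info=(type(error), error, tb))


def user_task():
    raise ValueError("boom in DAG code")


try:
    execute("my_task", user_task)
except ValueError as err:
    handle_failure("my_task", err)

Because of the `tb or error.__traceback__` fallback, an exception whose traceback never passes through the boundary frame is still logged with its full stack trace.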

airflow/task/task_runner/standard_task_runner.py

Lines changed: 12 additions & 6 deletions
@@ -73,23 +73,29 @@ def _start_by_fork(self):
             # [1:] - remove "airflow" from the start of the command
             args = parser.parse_args(self._command[1:])

+            # We prefer the job_id passed on the command-line because at this time, the
+            # task instance may not have been updated.
+            job_id = getattr(args, "job_id", self._task_instance.job_id)
             self.log.info('Running: %s', self._command)
-            self.log.info('Job %s: Subtask %s', self._task_instance.job_id, self._task_instance.task_id)
+            self.log.info('Job %s: Subtask %s', job_id, self._task_instance.task_id)

             proc_title = "airflow task runner: {0.dag_id} {0.task_id} {0.execution_date_or_run_id}"
-            if hasattr(args, "job_id"):
+            if job_id is not None:
                 proc_title += " {0.job_id}"
             setproctitle(proc_title.format(args))

             try:
                 args.func(args, dag=self.dag)
                 return_code = 0
-            except Exception:
+            except Exception as exc:
                 return_code = 1
-                self.log.exception(
-                    "Failed to execute job %s for task %s",
-                    self._task_instance.job_id,
+
+                self.log.error(
+                    "Failed to execute job %s for task %s (%s; %r)",
+                    job_id,
                     self._task_instance.task_id,
+                    exc,
+                    os.getpid(),
                 )
             finally:
                 # Explicitly flush any pending exception to Sentry if enabled
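A small illustration of the getattr fallback introduced above (standalone code, not part of the commit): the default is used only when the attribute is absent from the parsed args, so a job_id present on the command line wins over the task instance's value.

from argparse import Namespace

ti_job_id = 123  # stand-in for self._task_instance.job_id

args_with_id = Namespace(job_id=456)
args_without_id = Namespace()

# Attribute present: the command-line value is used, even if it were falsy.
assert getattr(args_with_id, "job_id", ti_job_id) == 456
# Attribute missing: fall back to the task instance's job_id.
assert getattr(args_without_id, "job_id", ti_job_id) == 123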

tests/models/test_taskinstance.py

Lines changed: 23 additions & 0 deletions
@@ -19,8 +19,10 @@
 import datetime
 import os
 import signal
+import sys
 import urllib
 from tempfile import NamedTemporaryFile
+from traceback import format_exception
 from typing import List, Optional, Union, cast
 from unittest import mock
 from unittest.mock import call, mock_open, patch
@@ -1843,6 +1845,27 @@ def fail():
             pass  # expected
         assert State.UP_FOR_RETRY == ti.state

+    def test_stacktrace_on_failure_starts_with_task_execute_method(self, dag_maker):
+        def fail():
+            raise AirflowException("maybe this will pass?")
+
+        with dag_maker(dag_id='test_retries_on_other_exceptions'):
+            task = PythonOperator(
+                task_id='test_raise_other_exception',
+                python_callable=fail,
+                retries=1,
+            )
+        ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
+        ti.task = task
+        with patch.object(TI, "log") as log, pytest.raises(AirflowException):
+            ti.run()
+        assert len(log.error.mock_calls) == 1
+        assert log.error.call_args[0] == ("Task failed with exception",)
+        exc_info = log.error.call_args[1]["exc_info"]
+        filename = exc_info[2].tb_frame.f_code.co_filename
+        formatted_exc = format_exception(*exc_info)
+        assert sys.modules[PythonOperator.__module__].__file__ == filename, "".join(formatted_exc)
+
     def _env_var_check_callback(self):
         assert 'test_echo_env_variables' == os.environ['AIRFLOW_CTX_DAG_ID']
         assert 'hive_in_python_op' == os.environ['AIRFLOW_CTX_TASK_ID']
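
The new test inspects the exc_info triple (type, value, traceback) passed to log.error and checks that the first traceback entry points into the module defining PythonOperator, i.e. that the logged trace starts at the task's execute call. A standalone illustration of that kind of inspection (hypothetical helper, not the Airflow test):

from traceback import format_exception


def first_frame_filename(exc_info):
    """Return the file of the first traceback entry in an exc_info triple."""
    exc_type, exc_value, tb = exc_info
    return tb.tb_frame.f_code.co_filename


try:
    raise ValueError("boom")
except ValueError as err:
    exc_info = (type(err), err, err.__traceback__)

# Here the traceback starts in this module; in the Airflow test it must start
# in the module that defines PythonOperator.
print(first_frame_filename(exc_info))
print("".join(format_exception(*exc_info)))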

tests/task/task_runner/test_standard_task_runner.py

Lines changed: 20 additions & 4 deletions
@@ -17,8 +17,10 @@
 # under the License.
 import logging
 import os
+import re
 import time
 from logging.config import dictConfig
+from tempfile import NamedTemporaryFile
 from unittest import mock

 import psutil
@@ -40,11 +42,13 @@

 DEFAULT_DATE = timezone.datetime(2016, 1, 1)

+TASK_FORMAT = '{{%(filename)s:%(lineno)d}} %(levelname)s - %(message)s'
+
 LOGGING_CONFIG = {
     'version': 1,
     'disable_existing_loggers': False,
     'formatters': {
-        'airflow.task': {'format': '[%(asctime)s] {{%(filename)s:%(lineno)d}} %(levelname)s - %(message)s'},
+        'airflow.task': {'format': TASK_FORMAT},
     },
     'handlers': {
         'console': {
@@ -197,19 +201,23 @@ def test_on_kill(self):
         dag = dagbag.dags.get('test_on_kill')
         task = dag.get_task('task1')

-        with create_session() as session:
+        with create_session() as session, NamedTemporaryFile("w", delete=False) as f:
             dag.create_dagrun(
                 run_id="test",
+                data_interval=(DEFAULT_DATE, DEFAULT_DATE),
                 state=State.RUNNING,
-                execution_date=DEFAULT_DATE,
                 start_date=DEFAULT_DATE,
                 session=session,
             )
-            ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+            ti = TaskInstance(task=task, run_id="test")
             job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True)
             session.commit()
+            ti.refresh_from_task(task)

             runner = StandardTaskRunner(job1)
+            handler = logging.StreamHandler(f)
+            handler.setFormatter(logging.Formatter(TASK_FORMAT))
+            runner.log.addHandler(handler)
             runner.start()

             with timeout(seconds=3):
@@ -232,6 +240,14 @@ def test_on_kill(self):
             logging.info(f"Terminating processes {processes} belonging to {runner_pgid} group")
             runner.terminate()
             session.close()  # explicitly close as `create_session`s commit will blow up otherwise
+            with open(f.name) as g:
+                logged = g.read()
+            os.unlink(f.name)
+
+            ti.refresh_from_db()
+            assert re.findall(r'ERROR - Failed to execute job (\S+) for task (\S+)', logged) == [
+                (str(ti.job_id), ti.task_id)
+            ], logged

             logging.info("Waiting for the on kill killed file to appear")
             with timeout(seconds=4):
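
The updated test captures the runner's log output by attaching a StreamHandler that writes to a temporary file, then asserts on the file contents with a regex. A standalone sketch of that capture pattern, under the assumption of a hypothetical logger name (not the Airflow test itself):

import logging
import os
import re
from tempfile import NamedTemporaryFile

TASK_FORMAT = '{{%(filename)s:%(lineno)d}} %(levelname)s - %(message)s'

log = logging.getLogger("sketch.task_runner")

with NamedTemporaryFile("w", delete=False) as f:
    handler = logging.StreamHandler(f)
    handler.setFormatter(logging.Formatter(TASK_FORMAT))
    log.addHandler(handler)
    # Emit a record shaped like the runner's failure message.
    log.error("Failed to execute job %s for task %s (%s; %r)", 42, "task1", "boom", os.getpid())
    handler.flush()
    log.removeHandler(handler)

with open(f.name) as g:
    logged = g.read()
os.unlink(f.name)

# Same regex assertion style as the test: extract (job_id, task_id) pairs.
assert re.findall(r'ERROR - Failed to execute job (\S+) for task (\S+)', logged) == [("42", "task1")]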
