diff --git a/debug_toolbar/panels/sql/forms.py b/debug_toolbar/panels/sql/forms.py index 0515c5c8e..bb83155f4 100644 --- a/debug_toolbar/panels/sql/forms.py +++ b/debug_toolbar/panels/sql/forms.py @@ -5,7 +5,7 @@ from django.db import connections from django.utils.functional import cached_property -from debug_toolbar.panels.sql.utils import reformat_sql +from debug_toolbar.panels.sql.utils import is_select_query, reformat_sql class SQLSelectForm(forms.Form): @@ -27,7 +27,7 @@ class SQLSelectForm(forms.Form): def clean_raw_sql(self): value = self.cleaned_data["raw_sql"] - if not value.lower().strip().startswith("select"): + if not is_select_query(value): raise ValidationError("Only 'select' queries are allowed.") return value diff --git a/debug_toolbar/panels/sql/panel.py b/debug_toolbar/panels/sql/panel.py index 879be38b0..7be5c4da6 100644 --- a/debug_toolbar/panels/sql/panel.py +++ b/debug_toolbar/panels/sql/panel.py @@ -12,7 +12,11 @@ from debug_toolbar.panels.sql import views from debug_toolbar.panels.sql.forms import SQLSelectForm from debug_toolbar.panels.sql.tracking import wrap_cursor -from debug_toolbar.panels.sql.utils import contrasting_color_generator, reformat_sql +from debug_toolbar.panels.sql.utils import ( + contrasting_color_generator, + is_select_query, + reformat_sql, +) from debug_toolbar.utils import render_stacktrace @@ -266,9 +270,7 @@ def generate_stats(self, request, response): query["sql"] = reformat_sql(query["sql"], with_toggle=True) query["is_slow"] = query["duration"] > sql_warning_threshold - query["is_select"] = ( - query["raw_sql"].lower().lstrip().startswith("select") - ) + query["is_select"] = is_select_query(query["raw_sql"]) query["rgb_color"] = self._databases[alias]["rgb_color"] try: diff --git a/debug_toolbar/panels/sql/utils.py b/debug_toolbar/panels/sql/utils.py index cb4eda348..b8fd34afe 100644 --- a/debug_toolbar/panels/sql/utils.py +++ b/debug_toolbar/panels/sql/utils.py @@ -86,6 +86,11 @@ def process(stmt): return "".join(escaped_value(token) for token in stmt.flatten()) +def is_select_query(sql): + # UNION queries can start with "(". + return sql.lower().lstrip(" (").startswith("select") + + def reformat_sql(sql, *, with_toggle=False): formatted = parse_sql(sql) if not with_toggle: diff --git a/docs/changes.rst b/docs/changes.rst index e82c598c2..74e495372 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -3,6 +3,7 @@ Change log Pending ------- +* Support select and explain buttons for ``UNION`` queries on PostgreSQL. 4.4.6 (2024-07-10) ------------------ diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index 332e9b1e8..8e105657b 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -729,6 +729,13 @@ def test_similar_and_duplicate_grouping(self): self.assertNotEqual(queries[0]["similar_color"], queries[3]["similar_color"]) self.assertNotEqual(queries[0]["duplicate_color"], queries[3]["similar_color"]) + def test_explain_with_union(self): + list(User.objects.filter(id__lt=20).union(User.objects.filter(id__gt=10))) + response = self.panel.process_request(self.request) + self.panel.generate_stats(self.request, response) + query = self.panel._queries[0] + self.assertTrue(query["is_select"]) + class SQLPanelMultiDBTestCase(BaseMultiDBTestCase): panel_id = "SQLPanel" diff --git a/tests/test_integration.py b/tests/test_integration.py index e6863e7a9..f3ea52b64 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -432,6 +432,29 @@ def test_sql_explain_checks_show_toolbar(self): ) self.assertEqual(response.status_code, 404) + @unittest.skipUnless( + connection.vendor == "postgresql", "Test valid only on PostgreSQL" + ) + def test_sql_explain_postgres_union_query(self): + """ + Confirm select queries that start with a parenthesis can be explained. + """ + url = "/__debug__/sql_explain/" + data = { + "signed": SignedDataForm.sign( + { + "sql": "(SELECT * FROM auth_user) UNION (SELECT * from auth_user)", + "raw_sql": "(SELECT * FROM auth_user) UNION (SELECT * from auth_user)", + "params": "{}", + "alias": "default", + "duration": "0", + } + ) + } + + response = self.client.post(url, data) + self.assertEqual(response.status_code, 200) + @unittest.skipUnless( connection.vendor == "postgresql", "Test valid only on PostgreSQL" )