Skip to content

Commit 2868ba3

Browse files
[REVIEW] Fast path when possible for non numeric aggregation (#236)
* added extra check for datetime * cleaned_up_code * added check for string * added check for cudfDtype and PandasDtype * fixed preference for native vs custom agg * removed first from agg * added StringDtype to test_compatibility.py * added better comment for why we take a different code path with StringsDtype * added test for datetime * changed native to built in * made docstring cleaner * removed unused import * Trigger Build * Minor typo Co-authored-by: Charles Blackmon-Luca <[email protected]>
1 parent e63990c commit 2868ba3

File tree

2 files changed

+83
-19
lines changed

2 files changed

+83
-19
lines changed

dask_sql/physical/rel/logical/aggregate.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
import dask.dataframe as dd
88
import pandas as pd
99

10+
try:
11+
import dask_cudf
12+
except ImportError:
13+
dask_cudf = None
14+
1015
from dask_sql.datacontainer import ColumnContainer, DataContainer
1116
from dask_sql.physical.rel.base import BaseRelPlugin
1217
from dask_sql.physical.rex.core.call import IsNullOperation
@@ -48,18 +53,42 @@ class AggregationSpecification:
4853
"""
4954
Most of the aggregations in SQL are already
5055
implemented 1:1 in dask and can just be called via their name
51-
(e.g. AVG is the mean). However sometimes those already
52-
implemented functions only work well for numerical
53-
functions. This small container class therefore
54-
can have an additional aggregation function, which is
55-
valid for non-numerical types.
56+
(e.g. AVG is the mean). However sometimes those
57+
implemented functions only work well for some datatypes.
58+
This small container class therefore
59+
can have a custom aggregation function, which is
60+
valid for unsupported dtypes.
5661
"""
5762

58-
def __init__(self, numerical_aggregation, non_numerical_aggregation=None):
59-
self.numerical_aggregation = numerical_aggregation
60-
self.non_numerical_aggregation = (
61-
non_numerical_aggregation or numerical_aggregation
62-
)
63+
def __init__(self, built_in_aggregation, custom_aggregation=None):
64+
self.built_in_aggregation = built_in_aggregation
65+
self.custom_aggregation = custom_aggregation or built_in_aggregation
66+
67+
def get_supported_aggregation(self, series):
68+
built_in_aggregation = self.built_in_aggregation
69+
70+
# built-in aggregations work well for numeric types
71+
if pd.api.types.is_numeric_dtype(series.dtype):
72+
return built_in_aggregation
73+
74+
# TODO: Add Categorical when support comes to dask-sql
75+
if built_in_aggregation in ["min", "max"]:
76+
if pd.api.types.is_datetime64_any_dtype(series.dtype):
77+
return built_in_aggregation
78+
79+
if pd.api.types.is_string_dtype(series.dtype):
80+
# If dask_cudf string dtype, return built-in aggregation
81+
if dask_cudf is not None and isinstance(series, dask_cudf.Series):
82+
return built_in_aggregation
83+
84+
# With pandas StringDtype built-in aggregations work
85+
# while with pandas ObjectDtype and Nulls built-in aggregations fail
86+
if isinstance(series, dd.Series) and isinstance(
87+
series.dtype, pd.StringDtype
88+
):
89+
return built_in_aggregation
90+
91+
return self.custom_aggregation
6392

6493

6594
class LogicalAggregatePlugin(BaseRelPlugin):
@@ -303,13 +332,9 @@ def _collect_aggregations(
303332
f"Aggregation function {aggregation_name} not implemented (yet)."
304333
)
305334
if isinstance(aggregation_function, AggregationSpecification):
306-
dtype = df[input_col].dtype
307-
if pd.api.types.is_numeric_dtype(dtype):
308-
aggregation_function = aggregation_function.numerical_aggregation
309-
else:
310-
aggregation_function = (
311-
aggregation_function.non_numerical_aggregation
312-
)
335+
aggregation_function = aggregation_function.get_supported_aggregation(
336+
df[input_col]
337+
)
313338

314339
# Finally, extract the output column name
315340
output_col = str(agg_call.getValue())

tests/integration/test_compatibility.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@
1919
from dask_sql import Context
2020

2121

22+
def cast_datetime_to_string(df):
23+
cols = df.select_dtypes(include=["datetime64[ns]"]).columns
24+
# Casting to object first as
25+
# directly converting to string loses second precision
26+
df[cols] = df[cols].astype("object").astype("string")
27+
return df
28+
29+
2230
def eq_sqlite(sql, **dfs):
2331
c = Context()
2432
engine = sqlite3.connect(":memory:")
@@ -30,6 +38,10 @@ def eq_sqlite(sql, **dfs):
3038
dask_result = c.sql(sql).compute().reset_index(drop=True)
3139
sqlite_result = pd.read_sql(sql, engine).reset_index(drop=True)
3240

41+
# casting to object to ensure equality with SQLite
42+
# which returns object dtype for datetime inputs
43+
dask_result = cast_datetime_to_string(dask_result)
44+
3345
# Make sure SQL and Dask use the same "NULL" value
3446
dask_result = dask_result.fillna(np.NaN)
3547
sqlite_result = sqlite_result.fillna(np.NaN)
@@ -54,6 +66,11 @@ def make_rand_df(size: int, **kwargs):
5466
r = [f"ssssss{x}" for x in range(10)]
5567
c = np.random.randint(10, size=size)
5668
s = np.array([r[x] for x in c])
69+
elif dt is pd.StringDtype:
70+
r = [f"ssssss{x}" for x in range(10)]
71+
c = np.random.randint(10, size=size)
72+
s = np.array([r[x] for x in c])
73+
s = pd.array(s, dtype="string")
5774
elif dt is datetime:
5875
rt = [datetime(2020, 1, 1) + timedelta(days=x) for x in range(10)]
5976
c = np.random.randint(10, size=size)
@@ -337,7 +354,14 @@ def test_agg_sum_avg():
337354

338355
def test_agg_min_max_no_group_by():
339356
a = make_rand_df(
340-
100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
357+
100,
358+
a=(int, 50),
359+
b=(str, 50),
360+
c=(int, 30),
361+
d=(str, 40),
362+
e=(float, 40),
363+
f=(pd.StringDtype, 40),
364+
g=(datetime, 40),
341365
)
342366
eq_sqlite(
343367
"""
@@ -352,6 +376,10 @@ def test_agg_min_max_no_group_by():
352376
MAX(d) AS max_d,
353377
MIN(e) AS min_e,
354378
MAX(e) AS max_e,
379+
MIN(f) as min_f,
380+
MAX(f) as max_f,
381+
MIN(g) as min_g,
382+
MAX(g) as max_g,
355383
MIN(a+e) AS mix_1,
356384
MIN(a)+MIN(e) AS mix_2
357385
FROM a
@@ -362,7 +390,14 @@ def test_agg_min_max_no_group_by():
362390

363391
def test_agg_min_max():
364392
a = make_rand_df(
365-
100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
393+
100,
394+
a=(int, 50),
395+
b=(str, 50),
396+
c=(int, 30),
397+
d=(str, 40),
398+
e=(float, 40),
399+
f=(pd.StringDtype, 40),
400+
g=(datetime, 40),
366401
)
367402
eq_sqlite(
368403
"""
@@ -374,6 +409,10 @@ def test_agg_min_max():
374409
MAX(d) AS max_d,
375410
MIN(e) AS min_e,
376411
MAX(e) AS max_e,
412+
MIN(f) AS min_f,
413+
MAX(f) AS max_f,
414+
MIN(g) AS min_g,
415+
MAX(g) AS max_g,
377416
MIN(a+e) AS mix_1,
378417
MIN(a)+MIN(e) AS mix_2
379418
FROM a GROUP BY a, b

0 commit comments

Comments
 (0)