diff --git a/dask_sql/physical/rel/logical/aggregate.py b/dask_sql/physical/rel/logical/aggregate.py
index 171e20c79..cfaf1a76a 100644
--- a/dask_sql/physical/rel/logical/aggregate.py
+++ b/dask_sql/physical/rel/logical/aggregate.py
@@ -7,6 +7,11 @@
 import dask.dataframe as dd
 import pandas as pd
 
+try:
+    import dask_cudf
+except ImportError:
+    dask_cudf = None
+
 from dask_sql.datacontainer import ColumnContainer, DataContainer
 from dask_sql.physical.rel.base import BaseRelPlugin
 from dask_sql.physical.rex.core.call import IsNullOperation
@@ -48,18 +53,42 @@ class AggregationSpecification:
     """
     Most of the aggregations in SQL are already
     implemented 1:1 in dask and can just be called via their name
-    (e.g. AVG is the mean). However sometimes those already
-    implemented functions only work well for numerical
-    functions. This small container class therefore
-    can have an additional aggregation function, which is
-    valid for non-numerical types.
+    (e.g. AVG is the mean). However, sometimes those
+    implemented functions only work well for some datatypes.
+    This small container class can therefore carry a
+    custom aggregation function, which is used for the
+    dtypes the built-in aggregation does not support.
     """
 
-    def __init__(self, numerical_aggregation, non_numerical_aggregation=None):
-        self.numerical_aggregation = numerical_aggregation
-        self.non_numerical_aggregation = (
-            non_numerical_aggregation or numerical_aggregation
-        )
+    def __init__(self, built_in_aggregation, custom_aggregation=None):
+        self.built_in_aggregation = built_in_aggregation
+        self.custom_aggregation = custom_aggregation or built_in_aggregation
+
+    def get_supported_aggregation(self, series):
+        built_in_aggregation = self.built_in_aggregation
+
+        # Built-in aggregations work well for numeric types
+        if pd.api.types.is_numeric_dtype(series.dtype):
+            return built_in_aggregation
+
+        # TODO: add Categorical once dask-sql supports it
+        if built_in_aggregation in ["min", "max"]:
+            if pd.api.types.is_datetime64_any_dtype(series.dtype):
+                return built_in_aggregation
+
+            if pd.api.types.is_string_dtype(series.dtype):
+                # dask_cudf string columns work with the built-in aggregations
+                if dask_cudf is not None and isinstance(series, dask_cudf.Series):
+                    return built_in_aggregation
+
+                # Built-in aggregations work with pandas StringDtype,
+                # but fail on object-dtype columns containing nulls
+                if isinstance(series, dd.Series) and isinstance(
+                    series.dtype, pd.StringDtype
+                ):
+                    return built_in_aggregation
+
+        return self.custom_aggregation
 
 
 class LogicalAggregatePlugin(BaseRelPlugin):
@@ -303,13 +332,9 @@ def _collect_aggregations(
                     f"Aggregation function {aggregation_name} not implemented (yet)."
                 )
             if isinstance(aggregation_function, AggregationSpecification):
-                dtype = df[input_col].dtype
-                if pd.api.types.is_numeric_dtype(dtype):
-                    aggregation_function = aggregation_function.numerical_aggregation
-                else:
-                    aggregation_function = (
-                        aggregation_function.non_numerical_aggregation
-                    )
+                aggregation_function = aggregation_function.get_supported_aggregation(
+                    df[input_col]
+                )
 
             # Finally, extract the output column name
             output_col = str(agg_call.getValue())
diff --git a/tests/integration/test_compatibility.py b/tests/integration/test_compatibility.py
index 6a8a334cc..760857356 100644
--- a/tests/integration/test_compatibility.py
+++ b/tests/integration/test_compatibility.py
@@ -19,6 +19,14 @@
 from dask_sql import Context
 
 
+def cast_datetime_to_string(df):
+    cols = df.select_dtypes(include=["datetime64[ns]"]).columns
+    # Cast to object first, since converting
+    # directly to string loses second precision
+    df[cols] = df[cols].astype("object").astype("string")
+    return df
+
+
 def eq_sqlite(sql, **dfs):
     c = Context()
     engine = sqlite3.connect(":memory:")
@@ -30,6 +38,10 @@ def eq_sqlite(sql, **dfs):
     dask_result = c.sql(sql).compute().reset_index(drop=True)
     sqlite_result = pd.read_sql(sql, engine).reset_index(drop=True)
 
+    # Cast datetimes to string to ensure equality with SQLite,
+    # which returns object dtype for datetime inputs
+    dask_result = cast_datetime_to_string(dask_result)
+
     # Make sure SQL and Dask use the same "NULL" value
     dask_result = dask_result.fillna(np.NaN)
     sqlite_result = sqlite_result.fillna(np.NaN)
@@ -54,6 +66,11 @@ def make_rand_df(size: int, **kwargs):
         r = [f"ssssss{x}" for x in range(10)]
         c = np.random.randint(10, size=size)
         s = np.array([r[x] for x in c])
+    elif dt is pd.StringDtype:
+        r = [f"ssssss{x}" for x in range(10)]
+        c = np.random.randint(10, size=size)
+        s = np.array([r[x] for x in c])
+        s = pd.array(s, dtype="string")
     elif dt is datetime:
         rt = [datetime(2020, 1, 1) + timedelta(days=x) for x in range(10)]
         c = np.random.randint(10, size=size)
@@ -337,7 +354,14 @@ def test_agg_sum_avg():
 
 def test_agg_min_max_no_group_by():
     a = make_rand_df(
-        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
+        100,
+        a=(int, 50),
+        b=(str, 50),
+        c=(int, 30),
+        d=(str, 40),
+        e=(float, 40),
+        f=(pd.StringDtype, 40),
+        g=(datetime, 40),
     )
     eq_sqlite(
         """
@@ -352,6 +376,10 @@ def test_agg_min_max_no_group_by():
             MAX(d) AS max_d,
             MIN(e) AS min_e,
            MAX(e) AS max_e,
+            MIN(f) AS min_f,
+            MAX(f) AS max_f,
+            MIN(g) AS min_g,
+            MAX(g) AS max_g,
             MIN(a+e) AS mix_1,
             MIN(a)+MIN(e) AS mix_2
         FROM a
@@ -362,7 +390,14 @@
 
 def test_agg_min_max():
     a = make_rand_df(
-        100, a=(int, 50), b=(str, 50), c=(int, 30), d=(str, 40), e=(float, 40)
+        100,
+        a=(int, 50),
+        b=(str, 50),
+        c=(int, 30),
+        d=(str, 40),
+        e=(float, 40),
+        f=(pd.StringDtype, 40),
+        g=(datetime, 40),
    )
     eq_sqlite(
         """
@@ -374,6 +409,10 @@
             MAX(d) AS max_d,
             MIN(e) AS min_e,
             MAX(e) AS max_e,
+            MIN(f) AS min_f,
+            MAX(f) AS max_f,
+            MIN(g) AS min_g,
+            MAX(g) AS max_g,
             MIN(a+e) AS mix_1,
             MIN(a)+MIN(e) AS mix_2
         FROM a GROUP BY a, b
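
Editor's note (not part of the patch): a minimal usage sketch of how the reworked AggregationSpecification picks an aggregation per column dtype. The "custom_min" fallback below is a hypothetical placeholder purely for illustration; in dask-sql the custom aggregation would normally be a dask Aggregation object, and the import path is taken from the file changed above.

    import pandas as pd
    import dask.dataframe as dd

    from dask_sql.physical.rel.logical.aggregate import AggregationSpecification

    # built-in aggregation "min", hypothetical custom fallback "custom_min"
    spec = AggregationSpecification("min", "custom_min")

    df = pd.DataFrame(
        {
            "numbers": [1.0, 2.0, None],
            "objects": pd.Series(["a", None, "c"], dtype="object"),
            "strings": pd.Series(["a", None, "c"], dtype="string"),
        }
    )
    ddf = dd.from_pandas(df, npartitions=1)

    # Numeric and pandas StringDtype columns keep the built-in "min";
    # object-dtype strings containing nulls fall back to the custom aggregation.
    assert spec.get_supported_aggregation(ddf["numbers"]) == "min"
    assert spec.get_supported_aggregation(ddf["strings"]) == "min"
    assert spec.get_supported_aggregation(ddf["objects"]) == "custom_min"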