|
7 | 7 | import dask.dataframe as dd
|
8 | 8 | import pandas as pd
|
9 | 9 |
|
| 10 | +try: |
| 11 | + import dask_cudf |
| 12 | +except ImportError: |
| 13 | + dask_cudf = None |
| 14 | + |
10 | 15 | from dask_sql.datacontainer import ColumnContainer, DataContainer
|
11 | 16 | from dask_sql.physical.rel.base import BaseRelPlugin
|
12 | 17 | from dask_sql.physical.rex.core.call import IsNullOperation
|
class AggregationSpecification:
    """
    Most of the aggregations in SQL are already
    implemented 1:1 in dask and can just be called via their name
    (e.g. AVG is the mean). However, sometimes those
    implemented functions only work well for some datatypes.
    This small container class therefore
    can hold a custom aggregation function, which is
    used for the dtypes the built-in aggregation does not support.
    """

    def __init__(self, built_in_aggregation, custom_aggregation=None):
        """
        Store the built-in aggregation and its fallback.

        Parameters
        ----------
        built_in_aggregation
            The aggregation used for dtypes it is known to handle
            (compared against the names "min" and "max" in
            `get_supported_aggregation`).
        custom_aggregation
            The aggregation used for every other dtype. When omitted
            (None), the built-in aggregation is used as the fallback too.
        """
        self.built_in_aggregation = built_in_aggregation
        # Explicit None check: only a missing custom aggregation falls back
        # to the built-in one, not any value that merely happens to be falsy.
        self.custom_aggregation = (
            built_in_aggregation
            if custom_aggregation is None
            else custom_aggregation
        )

    def get_supported_aggregation(self, series):
        """
        Pick the aggregation to use for `series` based on its dtype.

        Parameters
        ----------
        series
            The column to aggregate. Only its `.dtype` and its type
            (checked against `dask_cudf.Series` and `dd.Series`) are inspected.

        Returns
        -------
        `self.built_in_aggregation` when the dtype is numeric, or when the
        built-in aggregation is "min"/"max" and the column is a datetime
        column, a dask_cudf string column, or a dask column with pandas
        `StringDtype`; `self.custom_aggregation` in every other case.
        """
        built_in_aggregation = self.built_in_aggregation

        # built-in aggregations work well for numeric types
        if pd.api.types.is_numeric_dtype(series.dtype):
            return built_in_aggregation

        # TODO: Add Categorical when support comes to dask-sql
        if built_in_aggregation in ("min", "max"):
            if pd.api.types.is_datetime64_any_dtype(series.dtype):
                return built_in_aggregation

            if pd.api.types.is_string_dtype(series.dtype):
                # dask_cudf is an optional import (None when unavailable);
                # for its string columns the built-in aggregation is used.
                if dask_cudf is not None and isinstance(series, dask_cudf.Series):
                    return built_in_aggregation

                # With pandas StringDtype built-in aggregations work,
                # while with pandas object dtype and nulls they fail.
                if isinstance(series, dd.Series) and isinstance(
                    series.dtype, pd.StringDtype
                ):
                    return built_in_aggregation

        return self.custom_aggregation
63 | 92 |
|
64 | 93 |
|
65 | 94 | class LogicalAggregatePlugin(BaseRelPlugin):
|
@@ -303,13 +332,9 @@ def _collect_aggregations(
|
303 | 332 | f"Aggregation function {aggregation_name} not implemented (yet)."
|
304 | 333 | )
|
305 | 334 | if isinstance(aggregation_function, AggregationSpecification):
|
306 |
| - dtype = df[input_col].dtype |
307 |
| - if pd.api.types.is_numeric_dtype(dtype): |
308 |
| - aggregation_function = aggregation_function.numerical_aggregation |
309 |
| - else: |
310 |
| - aggregation_function = ( |
311 |
| - aggregation_function.non_numerical_aggregation |
312 |
| - ) |
| 335 | + aggregation_function = aggregation_function.get_supported_aggregation( |
| 336 | + df[input_col] |
| 337 | + ) |
313 | 338 |
|
314 | 339 | # Finally, extract the output column name
|
315 | 340 | output_col = str(agg_call.getValue())
|
|
0 commit comments