Skip to content

Support metadata on scalar values #16053

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions datafusion/core/src/execution/session_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1951,6 +1951,10 @@ impl SimplifyInfo for SessionSimplifyProvider<'_> {
/// Resolves the data type of `expr` by delegating to `Expr::get_type`
/// with this provider's `df_schema`; any error from that call is returned
/// to the caller unchanged.
fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result<DataType> {
expr.get_type(self.df_schema)
}

/// Exposes the `df_schema` held by this provider. This implementation
/// always returns `Some`.
fn get_schema(&self) -> Option<&DFSchema> {
Some(self.df_schema)
}
}

#[derive(Debug)]
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2238,7 +2238,7 @@ mod tests {
// verify that the plan correctly casts u8 to i64
// the cast from u8 to i64 for literal will be simplified, and get lit(int64(5))
// the cast here is implicit so has CastOptions with safe=true
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }";
let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5), metadata: None }, fail_on_overflow: false }";
assert!(format!("{exec_plan:?}").contains(expected));
Ok(())
}
Expand All @@ -2263,7 +2263,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), metadata: None }, "c1"), (Literal { value: Int64(NULL), metadata: None }, "c2"), (Literal { value: Int64(NULL), metadata: None }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#;

assert_eq!(format!("{cube:?}"), expected);

Expand All @@ -2290,7 +2290,7 @@ mod tests {
&session_state,
);

let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;
let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL), metadata: None }, "c1"), (Literal { value: Int64(NULL), metadata: None }, "c2"), (Literal { value: Int64(NULL), metadata: None }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#;

assert_eq!(format!("{rollup:?}"), expected);

Expand Down Expand Up @@ -2474,7 +2474,7 @@ mod tests {
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.

let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }";
let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\"), metadata: None }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\"), metadata: None }, fail_on_overflow: false }, fail_on_overflow: false }";

let actual = format!("{execution_plan:?}");
assert!(actual.contains(expected), "{}", actual);
Expand Down
6 changes: 5 additions & 1 deletion datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arrow::datatypes::{DataType, Field, Schema};
use chrono::{DateTime, TimeZone, Utc};
use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*};
use datafusion_common::cast::as_int32_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_common::{DFSchemaRef, ToDFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::builder::table_scan_with_filters;
Expand Down Expand Up @@ -71,6 +71,10 @@ impl SimplifyInfo for MyInfo {
/// Resolves the data type of `expr` by delegating to `Expr::get_type`
/// with a borrow of the schema stored in `self.schema`.
fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
expr.get_type(self.schema.as_ref())
}

/// Exposes a borrow of the schema stored in `self.schema`. This
/// implementation always returns `Some`.
fn get_schema(&self) -> Option<&DFSchema> {
Some(self.schema.as_ref())
}
}

impl From<DFSchemaRef> for MyInfo {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::array::{as_string_array, record_batch, Int8Array, UInt64Array};
use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array};
use arrow::array::{
builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array,
Int32Array, RecordBatch, StringArray,
Expand Down Expand Up @@ -1527,6 +1527,54 @@ async fn test_metadata_based_udf() -> Result<()> {
Ok(())
}

/// Runs `MetadataBasedUdf` over two `5u64` literals, only one of which is
/// aliased with the `modify_values => double_output` metadata entry.
///
/// The test expects `10` for the literal that carries the metadata and `5`
/// for the plain one.
#[tokio::test]
async fn test_metadata_based_udf_with_literal() -> Result<()> {
    let ctx = SessionContext::new();

    // Metadata is attached to the first literal only.
    let doubling_metadata = [("modify_values".to_string(), "double_output".to_string())]
        .into_iter()
        .collect();
    let literal_exprs = vec![
        lit(5u64).alias_with_metadata("lit_with_doubling", Some(doubling_metadata)),
        lit(5u64).alias("lit_no_doubling"),
    ];
    let df = ctx.sql("select 0;").await?.select(literal_exprs)?;

    // Apply the same UDF to both literal columns on top of the optimized plan.
    let custom_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new()));
    let udf_exprs = vec![
        custom_udf
            .call(vec![col("lit_with_doubling")])
            .alias("doubled_output"),
        custom_udf
            .call(vec![col("lit_no_doubling")])
            .alias("not_doubled_output"),
    ];
    let plan = LogicalPlanBuilder::from(df.into_optimized_plan()?)
        .project(udf_exprs)?
        .build()?;

    let actual = DataFrame::new(ctx.state(), plan).collect().await?;

    // Expected single-row batch: a nullable doubled column and a
    // non-nullable untouched column.
    let expected_fields = vec![
        Field::new("doubled_output", DataType::UInt64, true),
        Field::new("not_doubled_output", DataType::UInt64, false),
    ];
    let expected_columns = vec![
        create_array!(UInt64, [Some(10)]) as ArrayRef,
        create_array!(UInt64, [5]),
    ];
    let expected =
        RecordBatch::try_new(Arc::new(Schema::new(expected_fields)), expected_columns)?;

    assert_eq!(expected, actual[0]);

    Ok(())
}

/// This UDF is to test extension handling, both on the input and output
/// sides. For the input, we will handle the data differently if there is
/// the canonical extension type Bool8. For the output we will add a
Expand Down
12 changes: 10 additions & 2 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1512,8 +1512,16 @@ impl Expr {
|expr| {
// f_up: unalias on up so we can remove nested aliases like
// `(x as foo) as bar`
if let Expr::Alias(Alias { expr, .. }) = expr {
Ok(Transformed::yes(*expr))
if let Expr::Alias(alias) = expr {
match alias
.metadata
.as_ref()
.map(|h| h.is_empty())
.unwrap_or(true)
{
true => Ok(Transformed::yes(*alias.expr)),
false => Ok(Transformed::no(Expr::Alias(alias))),
}
} else {
Ok(Transformed::no(expr))
}
Expand Down
9 changes: 8 additions & 1 deletion datafusion/expr/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Structs and traits to provide the information needed for expression simplification.

use arrow::datatypes::DataType;
use datafusion_common::{DFSchemaRef, DataFusionError, Result};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};

use crate::{execution_props::ExecutionProps, Expr, ExprSchemable};

Expand All @@ -40,6 +40,9 @@ pub trait SimplifyInfo {

/// Returns data type of this expr needed for determining optimized int type of a value
fn get_data_type(&self, expr: &Expr) -> Result<DataType>;

/// Returns the schema, which may be needed to evaluate an `Expr`
fn get_schema(&self) -> Option<&DFSchema>;
}

/// Provides simplification information based on DFSchema and
Expand Down Expand Up @@ -106,6 +109,10 @@ impl SimplifyInfo for SimplifyContext<'_> {
/// Returns the `ExecutionProps` reference stored in `self.props`.
fn execution_props(&self) -> &ExecutionProps {
self.props
}

/// Borrows the schema held by this context, or `None` when the context
/// has no schema.
fn get_schema(&self) -> Option<&DFSchema> {
    self.schema.as_deref()
}
}

/// Was the expression simplified?
Expand Down
14 changes: 12 additions & 2 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,18 @@ fn is_expr_trivial(expr: &Expr) -> bool {
fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> {
expr.transform_up(|expr| {
match expr {
// remove any intermediate aliases
Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)),
// remove any intermediate aliases if they do not carry metadata
Expr::Alias(alias) => {
match alias
.metadata
.as_ref()
.map(|h| h.is_empty())
.unwrap_or(true)
{
true => Ok(Transformed::yes(*alias.expr)),
false => Ok(Transformed::no(Expr::Alias(alias))),
}
}
Expr::Column(col) => {
// Find index of column:
let idx = input.schema.index_of_column(&col)?;
Expand Down
27 changes: 22 additions & 5 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
/// Ok(DataType::Int32)
/// }
/// fn get_schema(&self) -> Option<&DFSchema> {
/// None
/// }
/// }
///
/// // Create the simplifier
Expand Down Expand Up @@ -227,12 +230,17 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
mut expr: Expr,
) -> Result<(Transformed<Expr>, u32)> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;

let empty_schema = DFSchema::empty();
let mut const_evaluator = ConstEvaluator::try_new(
self.info.execution_props(),
self.info.get_schema().unwrap_or(&empty_schema),
)?;
let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);

if self.canonicalize {
expr = expr.rewrite(&mut Canonicalizer::new()).data()?
expr = expr.rewrite(&mut Canonicalizer::new()).data()?;
}

// Evaluating constants can enable new simplifications and
Expand All @@ -249,14 +257,17 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
.transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
expr = data;
num_cycles += 1;

// Track if any transformation occurred
has_transformed = has_transformed || transformed;
if !transformed || num_cycles >= self.max_simplifier_cycles {
break;
}
}

// shorten inlist should be started after other inlist rules are applied
expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;

Ok((
Transformed::new_transformed(expr, has_transformed),
num_cycles,
Expand Down Expand Up @@ -586,12 +597,16 @@ impl<'a> ConstEvaluator<'a> {
/// Create a new `ConstEvaluator`. Session constants (such as
/// the time for `now()`) are taken from the passed
/// `execution_props`.
pub fn try_new(execution_props: &'a ExecutionProps) -> Result<Self> {
pub fn try_new(
execution_props: &'a ExecutionProps,
input_schema: &DFSchema,
) -> Result<Self> {
// The dummy column name is unused and doesn't matter as only
// expressions without column references can be evaluated
static DUMMY_COL_NAME: &str = ".";

let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
let input_schema = DFSchema::try_from(schema.clone())?;
let input_schema = input_schema.clone();
// Need a single "input" row to produce a single output row
let col = new_null_array(&DataType::Null, 1);
let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col])?;
Expand Down Expand Up @@ -639,8 +654,10 @@ impl<'a> ConstEvaluator<'a> {
Expr::ScalarFunction(ScalarFunction { func, .. }) => {
Self::volatility_ok(func.signature().volatility)
}
Expr::Alias(datafusion_expr::expr::Alias { metadata, .. }) => {
metadata.as_ref().map(|h| h.is_empty()).unwrap_or(true)
}
Expr::Literal(_)
| Expr::Alias(..)
| Expr::Unnest(_)
| Expr::BinaryExpr { .. }
| Expr::Not(_)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/expressions/dynamic_filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,15 +342,15 @@ mod test {
)
.unwrap();
let snap = dynamic_filter_1.snapshot().unwrap().unwrap();
insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#);
insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42), metadata: None }, fail_on_overflow: false }"#);
let dynamic_filter_2 = reassign_predicate_columns(
Arc::clone(&dynamic_filter) as Arc<dyn PhysicalExpr>,
&filter_schema_2,
false,
)
.unwrap();
let snap = dynamic_filter_2.snapshot().unwrap().unwrap();
insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"#);
insta::assert_snapshot!(format!("{snap:?}"), @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42), metadata: None }, fail_on_overflow: false }"#);
// Both filters allow evaluating the same expression
let batch_1 = RecordBatch::try_new(
Arc::clone(&filter_schema_1),
Expand Down
8 changes: 4 additions & 4 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1451,31 +1451,31 @@ mod tests {
let sql_string = fmt_sql(expr.as_ref()).to_string();
let display_string = expr.to_string();
assert_eq!(sql_string, "a IN (a, b)");
assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])");
assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), metadata: None }, Literal { value: Utf8(\"b\"), metadata: None }])");

// Test: a NOT IN ('a', 'b')
let list = vec![lit("a"), lit("b")];
let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?;
let sql_string = fmt_sql(expr.as_ref()).to_string();
let display_string = expr.to_string();
assert_eq!(sql_string, "a NOT IN (a, b)");
assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }])");
assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), metadata: None }, Literal { value: Utf8(\"b\"), metadata: None }])");

// Test: a IN ('a', 'b', NULL)
let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
let expr = in_list(Arc::clone(&col_a), list, &false, &schema)?;
let sql_string = fmt_sql(expr.as_ref()).to_string();
let display_string = expr.to_string();
assert_eq!(sql_string, "a IN (a, b, NULL)");
assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])");
assert_eq!(display_string, "Use a@0 IN (SET) ([Literal { value: Utf8(\"a\"), metadata: None }, Literal { value: Utf8(\"b\"), metadata: None }, Literal { value: Utf8(NULL), metadata: None }])");

// Test: a NOT IN ('a', 'b', NULL)
let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))];
let expr = in_list(Arc::clone(&col_a), list, &true, &schema)?;
let sql_string = fmt_sql(expr.as_ref()).to_string();
let display_string = expr.to_string();
assert_eq!(sql_string, "a NOT IN (a, b, NULL)");
assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"b\") }, Literal { value: Utf8(NULL) }])");
assert_eq!(display_string, "a@0 NOT IN (SET) ([Literal { value: Utf8(\"a\"), metadata: None }, Literal { value: Utf8(\"b\"), metadata: None }, Literal { value: Utf8(NULL), metadata: None }])");

Ok(())
}
Expand Down
Loading
Loading