Skip to content

Commit 7d7776f

Browse files
selection: normal: implement flattening (#270)
Summary: adds `FlatteningRules` to flatten nested unions/intersections and fluent rule composition via `RewriteRuleExt.then()`. updates `normalize()` to apply flattening before identity rules. Pull Request resolved: #270 Reviewed By: mariusae Differential Revision: D76664976 Pulled By: shayne-fletcher fbshipit-source-id: d2f7a2d0af63221e224699171e54c6ab02104ce4
1 parent 11f45bb commit 7d7776f

File tree

2 files changed

+136
-1
lines changed

2 files changed

+136
-1
lines changed

ndslice/src/selection/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ use serde::Serialize;
109109

110110
use crate::Slice;
111111
use crate::selection::normal::NormalizedSelection;
112+
use crate::selection::normal::RewriteRuleExt;
112113
use crate::shape;
113114
use crate::shape::ShapeError;
114115
use crate::slice::SliceError;
@@ -358,7 +359,7 @@ pub fn structurally_equal(a: &Selection, b: &Selection) -> bool {
358359
/// structure. It is designed to improve over time as additional
359360
/// rewrites (e.g., flattening, simplification) are introduced.
360361
pub fn normalize(sel: &Selection) -> NormalizedSelection {
361-
let rule = normal::IdentityRules;
362+
let rule = normal::FlatteningRules.then(normal::IdentityRules);
362363
sel.fold::<normal::NormalizedSelection>()
363364
.rewrite_bottom_up(&rule)
364365
}

ndslice/src/selection/normal.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,27 @@ pub trait RewriteRule {
119119
fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection;
120120
}
121121

122+
impl<R1: RewriteRule, R2: RewriteRule> RewriteRule for (R1, R2) {
123+
fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
124+
self.1.rewrite(self.0.rewrite(node))
125+
}
126+
}
127+
128+
/// Extension trait for composing rewrite rules in a fluent style.
129+
///
130+
/// This trait provides a `then` method that allows chaining rewrite
131+
/// rules together, creating a pipeline where rules are applied
132+
/// left-to-right.
133+
pub trait RewriteRuleExt: RewriteRule + Sized {
134+
/// Chains this rule with another rule, creating a composite rule
135+
/// that applies `self` first, then `other`.
136+
fn then<R: RewriteRule>(self, other: R) -> (Self, R) {
137+
(self, other)
138+
}
139+
}
140+
141+
impl<T: RewriteRule> RewriteRuleExt for T {}
142+
122143
impl From<NormalizedSelection> for Selection {
123144
/// Converts the normalized form back into a standard `Selection`.
124145
///
@@ -204,6 +225,53 @@ impl RewriteRule for IdentityRules {
204225
}
205226
}
206227

228+
/// A normalization rule that flattens nested unions and
229+
/// intersections.
230+
#[derive(Default)]
231+
pub struct FlatteningRules;
232+
233+
impl RewriteRule for FlatteningRules {
234+
// Flattening rewrites:
235+
//
236+
// - Union(a, Union(b, c)) → Union(a, b, c) // flatten nested unions
237+
// - Intersection(a, Intersection(b, c)) → Intersection(a, b, c) // flatten nested intersections
238+
fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
239+
use NormalizedSelection::*;
240+
241+
match node {
242+
Union(set) => {
243+
let mut flattened = BTreeSet::new();
244+
for item in set {
245+
match item {
246+
Union(inner_set) => {
247+
flattened.extend(inner_set);
248+
}
249+
other => {
250+
flattened.insert(other);
251+
}
252+
}
253+
}
254+
Union(flattened)
255+
}
256+
Intersection(set) => {
257+
let mut flattened = BTreeSet::new();
258+
for item in set {
259+
match item {
260+
Intersection(inner_set) => {
261+
flattened.extend(inner_set);
262+
}
263+
other => {
264+
flattened.insert(other);
265+
}
266+
}
267+
}
268+
Intersection(flattened)
269+
}
270+
_ => node,
271+
}
272+
}
273+
}
274+
207275
impl NormalizedSelection {
208276
pub fn rewrite_bottom_up(self, rule: &impl RewriteRule) -> Self {
209277
let mapped = self.trav(|child| child.rewrite_bottom_up(rule));
@@ -261,4 +329,70 @@ mod tests {
261329

262330
assert_structurally_eq!(&normed.into(), &expected);
263331
}
332+
333+
#[test]
334+
fn test_union_flattening() {
335+
use NormalizedSelection::*;
336+
337+
// Create Union(a, Union(b, c)) manually
338+
let inner_union = {
339+
let mut set = BTreeSet::new();
340+
set.insert(All(Box::new(True))); // represents 'b'
341+
set.insert(Any(Box::new(True))); // represents 'c'
342+
Union(set)
343+
};
344+
345+
let outer_union = {
346+
let mut set = BTreeSet::new();
347+
set.insert(First(Box::new(True))); // represents 'a'
348+
set.insert(inner_union);
349+
Union(set)
350+
};
351+
352+
let rule = FlatteningRules;
353+
let result = rule.rewrite(outer_union);
354+
355+
// Should be flattened to Union(a, b, c)
356+
if let Union(set) = result {
357+
assert_eq!(set.len(), 3);
358+
assert!(set.contains(&First(Box::new(True))));
359+
assert!(set.contains(&All(Box::new(True))));
360+
assert!(set.contains(&Any(Box::new(True))));
361+
} else {
362+
panic!("Expected Union, got {:?}", result);
363+
}
364+
}
365+
366+
#[test]
367+
fn test_intersection_flattening() {
368+
use NormalizedSelection::*;
369+
370+
// Create Intersection(a, Intersection(b, c)) manually
371+
let inner_intersection = {
372+
let mut set = BTreeSet::new();
373+
set.insert(All(Box::new(True))); // represents 'b'
374+
set.insert(Any(Box::new(True))); // represents 'c'
375+
Intersection(set)
376+
};
377+
378+
let outer_intersection = {
379+
let mut set = BTreeSet::new();
380+
set.insert(First(Box::new(True))); // represents 'a'
381+
set.insert(inner_intersection);
382+
Intersection(set)
383+
};
384+
385+
let rule = FlatteningRules;
386+
let result = rule.rewrite(outer_intersection);
387+
388+
// Should be flattened to Intersection(a, b, c)
389+
if let Intersection(set) = result {
390+
assert_eq!(set.len(), 3);
391+
assert!(set.contains(&First(Box::new(True))));
392+
assert!(set.contains(&All(Box::new(True))));
393+
assert!(set.contains(&Any(Box::new(True))));
394+
} else {
395+
panic!("Expected Intersection, got {:?}", result);
396+
}
397+
}
264398
}

0 commit comments

Comments
 (0)