@@ -119,6 +119,27 @@ pub trait RewriteRule {
119
119
fn rewrite ( & self , node : NormalizedSelection ) -> NormalizedSelection ;
120
120
}
121
121
122
+ impl < R1 : RewriteRule , R2 : RewriteRule > RewriteRule for ( R1 , R2 ) {
123
+ fn rewrite ( & self , node : NormalizedSelection ) -> NormalizedSelection {
124
+ self . 1 . rewrite ( self . 0 . rewrite ( node) )
125
+ }
126
+ }
127
+
128
+ /// Extension trait for composing rewrite rules in a fluent style.
129
+ ///
130
+ /// This trait provides a `then` method that allows chaining rewrite
131
+ /// rules together, creating a pipeline where rules are applied
132
+ /// left-to-right.
133
+ pub trait RewriteRuleExt : RewriteRule + Sized {
134
+ /// Chains this rule with another rule, creating a composite rule
135
+ /// that applies `self` first, then `other`.
136
+ fn then < R : RewriteRule > ( self , other : R ) -> ( Self , R ) {
137
+ ( self , other)
138
+ }
139
+ }
140
+
141
+ impl < T : RewriteRule > RewriteRuleExt for T { }
142
+
122
143
impl From < NormalizedSelection > for Selection {
123
144
/// Converts the normalized form back into a standard `Selection`.
124
145
///
@@ -204,6 +225,53 @@ impl RewriteRule for IdentityRules {
204
225
}
205
226
}
206
227
228
+ /// A normalization rule that flattens nested unions and
229
+ /// intersections.
230
+ #[ derive( Default ) ]
231
+ pub struct FlatteningRules ;
232
+
233
+ impl RewriteRule for FlatteningRules {
234
+ // Flattening rewrites:
235
+ //
236
+ // - Union(a, Union(b, c)) → Union(a, b, c) // flatten nested unions
237
+ // - Intersection(a, Intersection(b, c)) → Intersection(a, b, c) // flatten nested intersections
238
+ fn rewrite ( & self , node : NormalizedSelection ) -> NormalizedSelection {
239
+ use NormalizedSelection :: * ;
240
+
241
+ match node {
242
+ Union ( set) => {
243
+ let mut flattened = BTreeSet :: new ( ) ;
244
+ for item in set {
245
+ match item {
246
+ Union ( inner_set) => {
247
+ flattened. extend ( inner_set) ;
248
+ }
249
+ other => {
250
+ flattened. insert ( other) ;
251
+ }
252
+ }
253
+ }
254
+ Union ( flattened)
255
+ }
256
+ Intersection ( set) => {
257
+ let mut flattened = BTreeSet :: new ( ) ;
258
+ for item in set {
259
+ match item {
260
+ Intersection ( inner_set) => {
261
+ flattened. extend ( inner_set) ;
262
+ }
263
+ other => {
264
+ flattened. insert ( other) ;
265
+ }
266
+ }
267
+ }
268
+ Intersection ( flattened)
269
+ }
270
+ _ => node,
271
+ }
272
+ }
273
+ }
274
+
207
275
impl NormalizedSelection {
208
276
pub fn rewrite_bottom_up ( self , rule : & impl RewriteRule ) -> Self {
209
277
let mapped = self . trav ( |child| child. rewrite_bottom_up ( rule) ) ;
@@ -261,4 +329,70 @@ mod tests {
261
329
262
330
assert_structurally_eq ! ( & normed. into( ) , & expected) ;
263
331
}
332
+
333
+ #[ test]
334
+ fn test_union_flattening ( ) {
335
+ use NormalizedSelection :: * ;
336
+
337
+ // Create Union(a, Union(b, c)) manually
338
+ let inner_union = {
339
+ let mut set = BTreeSet :: new ( ) ;
340
+ set. insert ( All ( Box :: new ( True ) ) ) ; // represents 'b'
341
+ set. insert ( Any ( Box :: new ( True ) ) ) ; // represents 'c'
342
+ Union ( set)
343
+ } ;
344
+
345
+ let outer_union = {
346
+ let mut set = BTreeSet :: new ( ) ;
347
+ set. insert ( First ( Box :: new ( True ) ) ) ; // represents 'a'
348
+ set. insert ( inner_union) ;
349
+ Union ( set)
350
+ } ;
351
+
352
+ let rule = FlatteningRules ;
353
+ let result = rule. rewrite ( outer_union) ;
354
+
355
+ // Should be flattened to Union(a, b, c)
356
+ if let Union ( set) = result {
357
+ assert_eq ! ( set. len( ) , 3 ) ;
358
+ assert ! ( set. contains( & First ( Box :: new( True ) ) ) ) ;
359
+ assert ! ( set. contains( & All ( Box :: new( True ) ) ) ) ;
360
+ assert ! ( set. contains( & Any ( Box :: new( True ) ) ) ) ;
361
+ } else {
362
+ panic ! ( "Expected Union, got {:?}" , result) ;
363
+ }
364
+ }
365
+
366
+ #[ test]
367
+ fn test_intersection_flattening ( ) {
368
+ use NormalizedSelection :: * ;
369
+
370
+ // Create Intersection(a, Intersection(b, c)) manually
371
+ let inner_intersection = {
372
+ let mut set = BTreeSet :: new ( ) ;
373
+ set. insert ( All ( Box :: new ( True ) ) ) ; // represents 'b'
374
+ set. insert ( Any ( Box :: new ( True ) ) ) ; // represents 'c'
375
+ Intersection ( set)
376
+ } ;
377
+
378
+ let outer_intersection = {
379
+ let mut set = BTreeSet :: new ( ) ;
380
+ set. insert ( First ( Box :: new ( True ) ) ) ; // represents 'a'
381
+ set. insert ( inner_intersection) ;
382
+ Intersection ( set)
383
+ } ;
384
+
385
+ let rule = FlatteningRules ;
386
+ let result = rule. rewrite ( outer_intersection) ;
387
+
388
+ // Should be flattened to Intersection(a, b, c)
389
+ if let Intersection ( set) = result {
390
+ assert_eq ! ( set. len( ) , 3 ) ;
391
+ assert ! ( set. contains( & First ( Box :: new( True ) ) ) ) ;
392
+ assert ! ( set. contains( & All ( Box :: new( True ) ) ) ) ;
393
+ assert ! ( set. contains( & Any ( Box :: new( True ) ) ) ) ;
394
+ } else {
395
+ panic ! ( "Expected Intersection, got {:?}" , result) ;
396
+ }
397
+ }
264
398
}
0 commit comments