@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
15
15
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
16
16
include "mlir/Interfaces/InferTypeOpInterface.td"
17
17
include "mlir/Interfaces/SideEffectInterfaces.td"
18
+ include "mlir/Interfaces/ControlFlowInterfaces.td"
19
+ include "mlir/Interfaces/LoopLikeInterface.td"
18
20
19
21
//===----------------------------------------------------------------------===//
20
22
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
1304
1306
1305
1307
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
1306
1308
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1307
- "ForeachOp"]>]> {
1309
+ "ForeachOp", "IterateOp" ]>]> {
1308
1310
let summary = "Yield from sparse_tensor set-like operations";
1309
1311
let description = [{
1310
1312
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1513,6 +1515,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1513
1515
let hasVerifier = 1;
1514
1516
}
1515
1517
1518
+ def IterateOp : SparseTensor_Op<"iterate",
1519
+ [RecursiveMemoryEffects, RecursivelySpeculatable,
1520
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
1521
+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1522
+ "getYieldedValuesMutable"]>,
1523
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
1524
+ ["getEntrySuccessorOperands"]>,
1525
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1526
+
1527
+ let arguments = (ins AnySparseIterSpace:$iterSpace,
1528
+ Variadic<AnyType>:$initArgs,
1529
+ LevelSetAttr:$crdUsedLvls);
1530
+ let results = (outs Variadic<AnyType>:$results);
1531
+ let regions = (region SizedRegion<1>:$region);
1532
+
1533
+ let summary = "Iterate over a sparse iteration space";
1534
+ let description = [{
1535
+ The `sparse_tensor.iterate` operations represents a loop over the
1536
+ provided iteration space extracted from a specific sparse tensor.
1537
+ The operation defines an SSA value for a sparse iterator that points
1538
+ to the current stored element in the sparse tensor and SSA values
1539
+ for coordinates of the stored element. The coordinates are always
1540
+ converted to `index` type despite of the underlying sparse tensor
1541
+ storage. When coordinates are not used, the SSA values can be skipped
1542
+ by `_` symbols, which usually leads to simpler generated code after
1543
+ sparsification. For example:
1544
+
1545
+ ```mlir
1546
+ // The coordinate for level 0 is not used when iterating over a 2-D
1547
+ // iteration space.
1548
+ %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1549
+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1550
+ ```
1551
+
1552
+ `sparse_tensor.iterate` can also operate on loop-carried variables
1553
+ and returns the final values after loop termination.
1554
+ The initial values of the variables are passed as additional SSA operands
1555
+ to the iterator SSA value and used coordinate SSA values mentioned
1556
+ above. The operation region has an argument for the iterator, variadic
1557
+ arguments for specified (used) coordiates and followed by one argument
1558
+ for each loop-carried variable, representing the value of the variable
1559
+ at the current iteration.
1560
+ The body region must contain exactly one block that terminates with
1561
+ `sparse_tensor.yield`.
1562
+
1563
+ `sparse_tensor.iterate` results hold the final values after the last
1564
+ iteration. If the `sparse_tensor.iterate` defines any values, a yield
1565
+ must be explicitly present.
1566
+ The number and types of the `sparse_tensor.iterate` results must match
1567
+ the initial values in the iter_args binding and the yield operands.
1568
+
1569
+
1570
+ A nested `sparse_tensor.iterate` example that prints all the coordinates
1571
+ stored in the sparse input:
1572
+
1573
+ ```mlir
1574
+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1575
+ // Iterates over the first level of %sp
1576
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1577
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1578
+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1579
+ // Iterates over the second level of %sp
1580
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1581
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1582
+ %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1583
+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1584
+ vector.print %crd0 : index
1585
+ vector.print %crd1 : index
1586
+ }
1587
+ }
1588
+ }
1589
+
1590
+ ```
1591
+ }];
1592
+
1593
+ let extraClassDeclaration = [{
1594
+ unsigned getSpaceDim() {
1595
+ return getIterSpace().getType().getSpaceDim();
1596
+ }
1597
+ BlockArgument getIterator() {
1598
+ return getRegion().getArguments().front();
1599
+ }
1600
+ Block::BlockArgListType getCrds() {
1601
+ // The first block argument is iterator, the remaining arguments are
1602
+ // referenced coordinates.
1603
+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1604
+ }
1605
+ unsigned getNumRegionIterArgs() {
1606
+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1607
+ }
1608
+ }];
1609
+
1610
+ let hasVerifier = 1;
1611
+ let hasRegionVerifier = 1;
1612
+ let hasCustomAssemblyFormat = 1;
1613
+ }
1614
+
1516
1615
//===----------------------------------------------------------------------===//
1517
1616
// Sparse Tensor Debugging and Test-Only Operations.
1518
1617
//===----------------------------------------------------------------------===//
0 commit comments