Skip to content

Commit 8debcf0

Browse files
author
Peiming Liu
authored
[mlir][sparse] introduce sparse_tensor.iterate operation (#88807)
A `sparse_tensor.iterate` iterates over a sparse iteration space extracted from `sparse_tensor.extract_iteration_space` operation introduced in #88554. *DO NOT MERGE* before #88554
1 parent 002297b commit 8debcf0

File tree

8 files changed

+882
-1
lines changed

8 files changed

+882
-1
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
#include "mlir/IR/OpDefinition.h"
1818
#include "mlir/IR/OpImplementation.h"
1919
#include "mlir/IR/TensorEncoding.h"
20+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
22+
#include "mlir/Interfaces/LoopLikeInterface.h"
2123
#include "mlir/Interfaces/SideEffectInterfaces.h"
2224

25+
#include "llvm/ADT/bit.h"
26+
2327
//===----------------------------------------------------------------------===//
2428
//
2529
// Type aliases to help code be more self-documenting. Unfortunately
@@ -41,6 +45,40 @@ using Level = uint64_t;
4145
/// including the value `ShapedType::kDynamic` (for shapes).
4246
using Size = int64_t;
4347

48+
/// A simple wrapper to encode a bitset of defined (at most 64) levels.
49+
class LevelSet {
50+
uint64_t bits = 0;
51+
52+
public:
53+
LevelSet() = default;
54+
explicit LevelSet(uint64_t bits) : bits(bits) {}
55+
operator uint64_t() const { return bits; }
56+
57+
LevelSet &set(unsigned i) {
58+
assert(i < 64);
59+
bits |= 1 << i;
60+
return *this;
61+
}
62+
63+
LevelSet &operator|=(LevelSet lhs) {
64+
bits |= static_cast<uint64_t>(lhs);
65+
return *this;
66+
}
67+
68+
LevelSet &lshift(unsigned offset) {
69+
bits = bits << offset;
70+
return *this;
71+
}
72+
73+
bool operator[](unsigned i) const {
74+
assert(i < 64);
75+
return (bits & (1 << i)) != 0;
76+
}
77+
78+
unsigned count() const { return llvm::popcount(bits); }
79+
bool empty() const { return bits == 0; }
80+
};
81+
4482
} // namespace sparse_tensor
4583
} // namespace mlir
4684

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
1919
list<Trait> traits = []>
2020
: AttrDef<SparseTensor_Dialect, name, traits>;
2121

22+
//===----------------------------------------------------------------------===//
23+
// A simple bitset attribute wrapped over a single int64_t to encode a set of
24+
// sparse tensor levels.
25+
//===----------------------------------------------------------------------===//
26+
27+
def LevelSetAttr :
28+
TypedAttrBase<
29+
I64, "IntegerAttr",
30+
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
31+
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
32+
"LevelSet attribute"> {
33+
let returnType = [{::mlir::sparse_tensor::LevelSet}];
34+
let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
35+
}
36+
2237
//===----------------------------------------------------------------------===//
2338
// These attributes are just like `IndexAttr` except that they clarify whether
2439
// the index refers to a dimension (an axis of the semantic tensor) or a level

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616
include "mlir/Interfaces/InferTypeOpInterface.td"
1717
include "mlir/Interfaces/SideEffectInterfaces.td"
18+
include "mlir/Interfaces/ControlFlowInterfaces.td"
19+
include "mlir/Interfaces/LoopLikeInterface.td"
1820

1921
//===----------------------------------------------------------------------===//
2022
// Base class.
@@ -1277,7 +1279,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
12771279

12781280
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
12791281
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1280-
"ForeachOp"]>]>,
1282+
"ForeachOp", "IterateOp"]>]>,
12811283
Arguments<(ins Variadic<AnyType>:$results)> {
12821284
let summary = "Yield from sparse_tensor set-like operations";
12831285
let description = [{
@@ -1430,6 +1432,154 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
14301432
let hasVerifier = 1;
14311433
}
14321434

1435+
//===----------------------------------------------------------------------===//
1436+
// Sparse Tensor Iteration Operations.
1437+
//===----------------------------------------------------------------------===//
1438+
1439+
def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1440+
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
1441+
1442+
let arguments = (ins AnySparseTensor:$tensor,
1443+
Optional<AnySparseIterator>:$parentIter,
1444+
LevelAttr:$loLvl, LevelAttr:$hiLvl);
1445+
1446+
let results = (outs AnySparseIterSpace:$resultSpace);
1447+
1448+
let summary = "Extract an iteration space from a sparse tensor between certain levels";
1449+
let description = [{
1450+
Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
1451+
certian (consecutive) levels.
1452+
1453+
`tensor`: the input sparse tensor that defines the iteration space.
1454+
`parentIter`: the iterator for the previous level, at which the iteration space
1455+
at the current levels will be extracted.
1456+
`loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
1457+
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
1458+
iteration space.
1459+
1460+
Example:
1461+
```mlir
1462+
// Extracts a 1-D iteration space from a COO tensor at level 1.
1463+
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
1464+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1465+
```
1466+
}];
1467+
1468+
1469+
let extraClassDeclaration = [{
1470+
std::pair<Level, Level> getLvlRange() {
1471+
return std::make_pair(getLoLvl(), getHiLvl());
1472+
}
1473+
unsigned getSpaceDim() {
1474+
return getHiLvl() - getLoLvl();
1475+
}
1476+
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1477+
return getResultSpace().getType().getLvlTypes();
1478+
}
1479+
}];
1480+
1481+
let hasVerifier = 1;
1482+
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1483+
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1484+
}
1485+
1486+
def IterateOp : SparseTensor_Op<"iterate",
1487+
[RecursiveMemoryEffects, RecursivelySpeculatable,
1488+
DeclareOpInterfaceMethods<LoopLikeOpInterface,
1489+
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1490+
"getYieldedValuesMutable"]>,
1491+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
1492+
["getEntrySuccessorOperands"]>,
1493+
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1494+
1495+
let arguments = (ins AnySparseIterSpace:$iterSpace,
1496+
Variadic<AnyType>:$initArgs,
1497+
LevelSetAttr:$crdUsedLvls);
1498+
let results = (outs Variadic<AnyType>:$results);
1499+
let regions = (region SizedRegion<1>:$region);
1500+
1501+
let summary = "Iterate over a sparse iteration space";
1502+
let description = [{
1503+
The `sparse_tensor.iterate` operations represents a loop over the
1504+
provided iteration space extracted from a specific sparse tensor.
1505+
The operation defines an SSA value for a sparse iterator that points
1506+
to the current stored element in the sparse tensor and SSA values
1507+
for coordinates of the stored element. The coordinates are always
1508+
converted to `index` type despite of the underlying sparse tensor
1509+
storage. When coordinates are not used, the SSA values can be skipped
1510+
by `_` symbols, which usually leads to simpler generated code after
1511+
sparsification. For example:
1512+
1513+
```mlir
1514+
// The coordinate for level 0 is not used when iterating over a 2-D
1515+
// iteration space.
1516+
%sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1517+
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1518+
```
1519+
1520+
`sparse_tensor.iterate` can also operate on loop-carried variables
1521+
and returns the final values after loop termination.
1522+
The initial values of the variables are passed as additional SSA operands
1523+
to the iterator SSA value and used coordinate SSA values mentioned
1524+
above. The operation region has an argument for the iterator, variadic
1525+
arguments for specified (used) coordiates and followed by one argument
1526+
for each loop-carried variable, representing the value of the variable
1527+
at the current iteration.
1528+
The body region must contain exactly one block that terminates with
1529+
`sparse_tensor.yield`.
1530+
1531+
`sparse_tensor.iterate` results hold the final values after the last
1532+
iteration. If the `sparse_tensor.iterate` defines any values, a yield
1533+
must be explicitly present.
1534+
The number and types of the `sparse_tensor.iterate` results must match
1535+
the initial values in the iter_args binding and the yield operands.
1536+
1537+
1538+
A nested `sparse_tensor.iterate` example that prints all the coordinates
1539+
stored in the sparse input:
1540+
1541+
```mlir
1542+
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1543+
// Iterates over the first level of %sp
1544+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1545+
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1546+
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1547+
// Iterates over the second level of %sp
1548+
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1549+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1550+
%r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1551+
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1552+
vector.print %crd0 : index
1553+
vector.print %crd1 : index
1554+
}
1555+
}
1556+
}
1557+
1558+
```
1559+
}];
1560+
1561+
let extraClassDeclaration = [{
1562+
unsigned getSpaceDim() {
1563+
return getIterSpace().getType().getSpaceDim();
1564+
}
1565+
BlockArgument getIterator() {
1566+
return getRegion().getArguments().front();
1567+
}
1568+
Block::BlockArgListType getCrds() {
1569+
// The first block argument is iterator, the remaining arguments are
1570+
// referenced coordinates.
1571+
return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1572+
}
1573+
unsigned getNumRegionIterArgs() {
1574+
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1575+
}
1576+
}];
1577+
1578+
let hasVerifier = 1;
1579+
let hasRegionVerifier = 1;
1580+
let hasCustomAssemblyFormat = 1;
1581+
}
1582+
14331583
//===----------------------------------------------------------------------===//
14341584
// Sparse Tensor Debugging and Test-Only Operations.
14351585
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,99 @@ def SparseTensorStorageSpecifier
7272
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
7373
"::mlir::sparse_tensor::StorageSpecifierType">;
7474

75+
//===----------------------------------------------------------------------===//
76+
// Sparse Tensor Iteration Types.
77+
//===----------------------------------------------------------------------===//
78+
79+
def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
80+
let mnemonic = "iter_space";
81+
82+
let description = [{
83+
A sparse iteration space that represents an abstract N-D (sparse) iteration space
84+
extracted from a sparse tensor.
85+
86+
Examples:
87+
88+
```mlir
89+
// An iteration space extracted from a CSR tensor between levels [0, 2).
90+
!iter_space<#CSR, lvls = 0 to 2>
91+
```
92+
}];
93+
94+
let parameters = (ins
95+
SparseTensorEncodingAttr : $encoding,
96+
"Level" : $loLvl,
97+
"Level" : $hiLvl
98+
);
99+
100+
let extraClassDeclaration = [{
101+
/// The the dimension of the iteration space.
102+
unsigned getSpaceDim() const {
103+
return getHiLvl() - getLoLvl();
104+
}
105+
106+
/// Get the level types for the iteration space.
107+
ArrayRef<LevelType> getLvlTypes() const {
108+
return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
109+
}
110+
111+
/// Whether the iteration space is unique (i.e., no duplicated coordinate).
112+
bool isUnique() {
113+
return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
114+
}
115+
116+
/// Get the corresponding iterator type.
117+
::mlir::sparse_tensor::IteratorType getIteratorType() const;
118+
}];
119+
120+
let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
121+
}
122+
123+
def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
124+
let mnemonic = "iterator";
125+
126+
let description = [{
127+
An iterator that points to the current element in the corresponding iteration space.
128+
129+
Examples:
130+
131+
```mlir
132+
// An iterator that iterates over a iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
133+
!iterator<#CSR, lvls = 0 to 2>
134+
```
135+
}];
136+
137+
let parameters = (ins
138+
SparseTensorEncodingAttr : $encoding,
139+
"Level" : $loLvl,
140+
"Level" : $hiLvl
141+
);
142+
143+
let extraClassDeclaration = [{
144+
/// Get the corresponding iteration space type.
145+
::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;
146+
147+
unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
148+
ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
149+
bool isUnique() { return getIterSpaceType().isUnique(); }
150+
}];
151+
152+
let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
153+
}
154+
155+
def IsSparseSparseIterSpaceTypePred
156+
: CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
157+
158+
def IsSparseSparseIteratorTypePred
159+
: CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
160+
161+
def AnySparseIterSpace
162+
: Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
163+
"::mlir::sparse_tensor::IterSpaceType">;
164+
165+
def AnySparseIterator
166+
: Type<IsSparseSparseIteratorTypePred, "sparse iterator",
167+
"::mlir::sparse_tensor::IteratorType">;
168+
169+
75170
#endif // SPARSETENSOR_TYPES

0 commit comments

Comments
 (0)