Skip to content

Commit 5fd211c

Browse files
committed
[mlir][linalg] Implement TilingInterface for winograd operators
In order to support arbitrary size input data of conv2d, implement TilingInterface for winograd operators. Before converting winograd operators into nested loops with matrix multiply, tile the input of conv2d into the supported size first. Add a transform operator structured.decompose_winograd_op to decompose winograd operators. Before applying the transform op, use tile_using_for to tile the input data into supported size. The test case shows how to tile and decompose winograd operators.
1 parent 14ce369 commit 5fd211c

File tree

7 files changed

+758
-3
lines changed

7 files changed

+758
-3
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154
let hasVerifier = 1;
155155
}
156156

157-
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
157+
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158+
[DeclareOpInterfaceMethods<TilingInterface,
159+
["getIterationDomain",
160+
"getLoopIteratorTypes",
161+
"getResultTilePosition",
162+
"getTiledImplementation"]>]> {
158163
let summary = "Winograd filter transform operator";
159164
let description = [{
160165
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -192,7 +197,12 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
192197
let hasVerifier = 1;
193198
}
194199

195-
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
200+
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
201+
[DeclareOpInterfaceMethods<TilingInterface,
202+
["getIterationDomain",
203+
"getLoopIteratorTypes",
204+
"getResultTilePosition",
205+
"getTiledImplementation"]>]> {
196206
let summary = "Winograd input transform operator";
197207
let description = [{
198208
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -230,7 +240,12 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
230240
let hasVerifier = 1;
231241
}
232242

233-
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
243+
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
244+
[DeclareOpInterfaceMethods<TilingInterface,
245+
["getIterationDomain",
246+
"getLoopIteratorTypes",
247+
"getResultTilePosition",
248+
"getTiledImplementation"]>]> {
234249
let summary = "Winograd output transform operator";
235250
let description = [{
236251
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2638,4 +2638,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
26382638
}];
26392639
}
26402640

2641+
def DecomposeWinogradOp : Op<Transform_Dialect,
2642+
"structured.decompose_winograd_op",
2643+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2644+
TransformOpInterface, TransformEachOpTrait,
2645+
ReportTrackingListenerFailuresOpTrait]> {
2646+
let description = [{
2647+
Decompose winograd operators. It will convert filter, input and output
2648+
transform operators into a combination of scf, tensor, and linalg
2649+
equivalent operators. Before applying this transform operator, users
2650+
need to tile winograd transform operators into supported sizes.
2651+
2652+
#### Return modes:
2653+
2654+
This operation fails if `target` is unsupported. Otherwise, the operation
2655+
succeeds and returns a handle of the sequence that replaces the original
2656+
operator.
2657+
}];
2658+
2659+
let arguments = (ins TransformHandleTypeInterface:$target);
2660+
let results = (outs TransformHandleTypeInterface:$transformed);
2661+
2662+
let assemblyFormat =
2663+
"$target attr-dict `:` functional-type($target, results)";
2664+
2665+
let builders = [
2666+
OpBuilder<(ins "Value":$target)>
2667+
];
2668+
2669+
let extraClassDeclaration = [{
2670+
::mlir::DiagnosedSilenceableFailure applyToOne(
2671+
::mlir::transform::TransformRewriter &rewriter,
2672+
::mlir::Operation *target,
2673+
::mlir::transform::ApplyToEachResultList &results,
2674+
::mlir::transform::TransformState &state);
2675+
}];
2676+
}
2677+
26412678
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,6 +1319,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13191319
linalg::Conv2DNhwcFhwcOp op, int64_t m,
13201320
int64_t r);
13211321

1322+
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
1323+
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
1324+
/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
1325+
/// C. After the rewriting, we get
1326+
///
1327+
/// scf.for %f = lo_f to hi_f step 1
1328+
/// scf.for %c = lo_c to hi_c step 1
1329+
/// %extracted = extract filter<h x w> from filter<f x h x w x c>
1330+
/// %ret = linalg.matmul G, %extracted
1331+
/// %ret = linalg.matmul %ret, GT
1332+
/// %inserted = insert %ret into filter<h x w x c x f>
1333+
FailureOr<Operation *>
1334+
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
1335+
linalg::WinogradFilterTransformOp op);
1336+
1337+
/// Rewrite linalg.winograd_input_transform. The data layout of the input is
1338+
/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
1339+
/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
1340+
/// C. After the rewriting, we get
1341+
///
1342+
/// scf.for %n = lo_n to hi_n step 1
1343+
/// scf.for %c = lo_c to hi_c step 1
1344+
/// %extracted = extract input<h x w> from input<n x h x w x c>
1345+
/// %ret = linalg.matmul BT, %extracted
1346+
/// %ret = linalg.matmul %ret, B
1347+
/// %inserted = insert %ret into input<h x w x n x c>
1348+
FailureOr<Operation *>
1349+
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
1350+
linalg::WinogradInputTransformOp op);
1351+
1352+
/// Rewrite linalg.winograd_output_transform. The data layout of the output is
1353+
/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
1354+
/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
1355+
/// F. After the transformation, we get
1356+
///
1357+
/// scf.for %n = lo_n to hi_n step 1
1358+
/// scf.for %f = lo_f to hi_f step 1
1359+
/// %extracted = extract input<h x w> from result<h x w x n x f>
1360+
/// %ret = linalg.matmul AT, %extracted
1361+
/// %ret = linalg.matmul %ret, A
1362+
/// %inserted = insert %ret into ret<n x h x w x f>
1363+
FailureOr<Operation *>
1364+
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
1365+
linalg::WinogradOutputTransformOp op);
1366+
13221367
//===----------------------------------------------------------------------===//
13231368
// Rewrite patterns wrapping transformations.
13241369
// TODO: every single such pattern should be a close to noop wrapper around a

0 commit comments

Comments
 (0)