Skip to content

Commit 5187224

Browse files
committed
[mlir][linalg] Implement TilingInterface for winograd operators
In order to support arbitrary size input data of conv2d, implement TilingInterface for winograd operators. Before converting winograd operators into nested loops with matrix multiply, tile the input of conv2d into the supported size first. Add a transform operator structured.decompose_winograd_op to decompose winograd operators. Before applying the transform op, use tile_using_for to tile the input data into supported size. The test case shows how to tile and decompose winograd operators.
1 parent 6a55b00 commit 5187224

File tree

7 files changed

+595
-3
lines changed

7 files changed

+595
-3
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,12 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154
let hasVerifier = 1;
155155
}
156156

157-
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
157+
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158+
[DeclareOpInterfaceMethods<TilingInterface,
159+
["getIterationDomain",
160+
"getLoopIteratorTypes",
161+
"getResultTilePosition",
162+
"getTiledImplementation"]>]> {
158163
let summary = "Winograd filter transform operator";
159164
let description = [{
160165
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -191,7 +196,12 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
191196
}];
192197
}
193198

194-
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
199+
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
200+
[DeclareOpInterfaceMethods<TilingInterface,
201+
["getIterationDomain",
202+
"getLoopIteratorTypes",
203+
"getResultTilePosition",
204+
"getTiledImplementation"]>]> {
195205
let summary = "Winograd input transform operator";
196206
let description = [{
197207
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
@@ -228,7 +238,12 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
228238
}];
229239
}
230240

231-
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
241+
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
242+
[DeclareOpInterfaceMethods<TilingInterface,
243+
["getIterationDomain",
244+
"getLoopIteratorTypes",
245+
"getResultTilePosition",
246+
"getTiledImplementation"]>]> {
232247
let summary = "Winograd output transform operator";
233248
let description = [{
234249
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2638,4 +2638,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
26382638
}];
26392639
}
26402640

2641+
def DecomposeWinogradOp : Op<Transform_Dialect,
2642+
"structured.decompose_winograd_op",
2643+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2644+
TransformOpInterface, TransformEachOpTrait,
2645+
ReportTrackingListenerFailuresOpTrait]> {
2646+
let description = [{
2647+
Decompose winograd operators. It will convert filter, input and output
2648+
transform operators into a combination of scf, tensor, and linalg
2649+
equivalent operators. Before applying this transform operator, users
2650+
need to tile winograd transform operators into supported sizes.
2651+
2652+
#### Return modes:
2653+
2654+
This operation fails if `target` is unsupported. Otherwise, the operation
2655+
succeeds and returns a handle of the sequence that replaces the original
2656+
operator.
2657+
}];
2658+
2659+
let arguments = (ins TransformHandleTypeInterface:$target);
2660+
let results = (outs TransformHandleTypeInterface:$transformed);
2661+
2662+
let assemblyFormat =
2663+
"$target attr-dict `:` functional-type($target, results)";
2664+
2665+
let builders = [
2666+
OpBuilder<(ins "Value":$target)>
2667+
];
2668+
2669+
let extraClassDeclaration = [{
2670+
::mlir::DiagnosedSilenceableFailure applyToOne(
2671+
::mlir::transform::TransformRewriter &rewriter,
2672+
::mlir::Operation *target,
2673+
::mlir::transform::ApplyToEachResultList &results,
2674+
::mlir::transform::TransformState &state);
2675+
}];
2676+
}
2677+
26412678
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,6 +1319,51 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
13191319
linalg::Conv2DNhwcFhwcOp op, int64_t m,
13201320
int64_t r);
13211321

1322+
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
1323+
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
1324+
/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
1325+
/// C. After the rewriting, we get
1326+
///
1327+
/// scf.for %f = lo_f to hi_f step 1
1328+
/// scf.for %c = lo_c to hi_c step 1
1329+
/// %extracted = extract filter<h x w> from filter<f x h x w x c>
1330+
/// %ret = linalg.matmul G, %extracted
1331+
/// %ret = linalg.matmul %ret, GT
1332+
/// %inserted = insert %ret into filter<h x w x c x f>
1333+
FailureOr<Operation *>
1334+
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
1335+
linalg::WinogradFilterTransformOp op);
1336+
1337+
/// Rewrite linalg.winograd_input_transform. The data layout of the input is
1338+
/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
1339+
/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
1340+
/// C. After the rewriting, we get
1341+
///
1342+
/// scf.for %n = lo_n to hi_n step 1
1343+
/// scf.for %c = lo_c to hi_c step 1
1344+
/// %extracted = extract input<h x w> from input<n x h x w x c>
1345+
/// %ret = linalg.matmul BT, %extracted
1346+
/// %ret = linalg.matmul %ret, B
1347+
/// %inserted = insert %ret into input<h x w x n x c>
1348+
FailureOr<Operation *>
1349+
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
1350+
linalg::WinogradInputTransformOp op);
1351+
1352+
/// Rewrite linalg.winograd_output_transform. The data layout of the output is
1353+
/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
1354+
/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
1355+
/// F. After the transformation, we get
1356+
///
1357+
/// scf.for %n = lo_n to hi_n step 1
1358+
/// scf.for %f = lo_f to hi_f step 1
1359+
/// %extracted = extract input<h x w> from result<h x w x n x f>
1360+
/// %ret = linalg.matmul AT, %extracted
1361+
/// %ret = linalg.matmul %ret, A
1362+
/// %inserted = insert %ret into ret<n x h x w x f>
1363+
FailureOr<Operation *>
1364+
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
1365+
linalg::WinogradOutputTransformOp op);
1366+
13221367
//===----------------------------------------------------------------------===//
13231368
// Rewrite patterns wrapping transformations.
13241369
// TODO: every single such pattern should be a close to noop wrapper around a

0 commit comments

Comments
 (0)