Skip to content

Commit 7727abb

Browse files
Hsiangkai and aryanshukla
authored and committed
[mlir][linalg] Add transform operator for Winograd Conv2D algorithm (llvm#96182)
Add a transform operation structured.winograd_conv2d to convert linalg.conv_2d_nhwc_fhwc to Linalg winograd operations. Reviewers: ftynse, Max191, GeorgeARM, nicolasvasilache, MaheshRavishankar, dcaballe, rengolin Reviewed By: ftynse, Max191 Pull Request: llvm#96182
1 parent e78c6ac commit 7727abb

File tree

5 files changed

+173
-1
lines changed

5 files changed

+173
-1
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2646,4 +2646,55 @@ def MapCopyToThreadsOp :
26462646
}];
26472647
}
26482648

2649+
//===----------------------------------------------------------------------===//
2650+
// Winograd Conv2D
2651+
//===----------------------------------------------------------------------===//
2652+
2653+
def WinogradConv2DOp : Op<Transform_Dialect,
2654+
"structured.winograd_conv2d",
2655+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2656+
TransformOpInterface, TransformEachOpTrait,
2657+
ReportTrackingListenerFailuresOpTrait]> {
2658+
let description = [{
2659+
Winograd Conv2D algorithm will convert linalg Conv2D operation into batched
2660+
matrix multiply. Before the matrix multiply, it will convert filter and
2661+
input into a format suitable for batched matrix multiply. After the matrix
2662+
multiply, it will convert output to the final result tensor.
2663+
2664+
The algorithm F(m x m, r x r) is
2665+
2666+
Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
2667+
2668+
The size of output Y is m x m. The size of filter g is r x r. The size of
2669+
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
2670+
transformation matrices.
2671+
2672+
#### Return modes:
2673+
2674+
This operation produces a silenceable failure if `target` is unsupported.
2675+
Otherwise, the operation succeeds and returns a handle of the sequence that
2676+
replaces the original convolution.
2677+
}];
2678+
2679+
let arguments = (ins TransformHandleTypeInterface:$target,
2680+
I64Attr:$m,
2681+
I64Attr:$r);
2682+
let results = (outs TransformHandleTypeInterface:$transformed);
2683+
2684+
let assemblyFormat =
2685+
"$target attr-dict `:` functional-type($target, results)";
2686+
2687+
let builders = [
2688+
OpBuilder<(ins "Value":$target)>
2689+
];
2690+
2691+
let extraClassDeclaration = [{
2692+
::mlir::DiagnosedSilenceableFailure applyToOne(
2693+
::mlir::transform::TransformRewriter &rewriter,
2694+
::mlir::linalg::LinalgOp target,
2695+
::mlir::transform::ApplyToEachResultList &results,
2696+
::mlir::transform::TransformState &state);
2697+
}];
2698+
}
2699+
26492700
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,6 +1332,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
13321332
linalg::BatchMatmulOp op,
13331333
bool transposeLHS = true);
13341334

1335+
/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
1336+
/// F(m x m, r x r). m is the dimension size of output and r is the dimension
1337+
/// size of filter.
1338+
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1339+
linalg::Conv2DNhwcFhwcOp op, int64_t m,
1340+
int64_t r);
1341+
13351342
//===----------------------------------------------------------------------===//
13361343
// Rewrite patterns wrapping transformations.
13371344
// TODO: every single such pattern should be a close to noop wrapper around a

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3711,6 +3711,37 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
37113711
return DiagnosedSilenceableFailure::success();
37123712
}
37133713

3714+
//===----------------------------------------------------------------------===//
3715+
// WinogradConv2DOp
3716+
//===----------------------------------------------------------------------===//
3717+
3718+
DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
3719+
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3720+
transform::ApplyToEachResultList &results,
3721+
transform::TransformState &state) {
3722+
rewriter.setInsertionPoint(target);
3723+
FailureOr<Operation *> maybeTransformed = failure();
3724+
bool supported = TypeSwitch<Operation *, bool>(target)
3725+
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
3726+
maybeTransformed =
3727+
winogradConv2D(rewriter, op, getM(), getR());
3728+
return true;
3729+
})
3730+
.Default([&](Operation *op) { return false; });
3731+
3732+
if (!supported) {
3733+
return emitSilenceableError()
3734+
<< "this operation is not supported to convert to Winograd Conv2D";
3735+
}
3736+
3737+
if (supported && failed(maybeTransformed)) {
3738+
return emitSilenceableError() << "apply Winograd Conv2D failed";
3739+
}
3740+
3741+
results.push_back(*maybeTransformed);
3742+
return DiagnosedSilenceableFailure::success();
3743+
}
3744+
37143745
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
37153746

37163747
#define GET_OP_CLASSES

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1616
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1717
#include "mlir/Dialect/Tensor/IR/Tensor.h"
18+
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
1819
#include "mlir/Dialect/Utils/StaticValueUtils.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1921
#include "llvm/Support/MathExtras.h"
2022

2123
namespace mlir {
@@ -156,7 +158,6 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
156158
auto filterType = cast<ShapedType>(filter.getType());
157159
auto outputType = cast<ShapedType>(output.getType());
158160

159-
// TODO: Should we support dynamic shapes?
160161
if (!inputType.hasStaticShape())
161162
return rewriter.notifyMatchFailure(convOp,
162163
"expected a static shape for the input");
@@ -316,6 +317,12 @@ class WinogradConv2DNhwcFhwc final
316317
} // end anonymous namespace
317318

318319
//===----------------------------------------------------------------------===//
320+
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
321+
linalg::Conv2DNhwcFhwcOp op, int64_t m,
322+
int64_t r) {
323+
return winogradConv2DHelper(rewriter, op, m, r);
324+
}
325+
319326
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
320327
int64_t r) {
321328
MLIRContext *context = patterns.getContext();
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s
2+
3+
func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
4+
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
5+
return %0 : tensor<2x8x8x2xf32>
6+
}
7+
8+
module attributes {transform.with_named_sequence} {
9+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
10+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
11+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
12+
transform.yield
13+
}
14+
}
15+
16+
// CHECK-LABEL: func.func @conv2d
17+
// CHECK: linalg.winograd_filter_transform m(4) r(3)
18+
// CHECK: linalg.winograd_input_transform m(4) r(3)
19+
// CHECK: linalg.batch_matmul
20+
// CHECK: linalg.winograd_output_transform m(4) r(3)
21+
22+
// -----
23+
24+
func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
25+
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
26+
return %0 : tensor<2x9x9x2xf32>
27+
}
28+
29+
module attributes {transform.with_named_sequence} {
30+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
31+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
32+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
33+
transform.yield
34+
}
35+
}
36+
37+
// CHECK-LABEL: func.func @conv2d_unaligned
38+
// CHECK: linalg.winograd_filter_transform m(4) r(3)
39+
// CHECK: tensor.pad
40+
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
41+
// CHECK: linalg.winograd_input_transform m(4) r(3)
42+
// CHECK: tensor.pad
43+
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
44+
// CHECK: linalg.winograd_output_transform m(4) r(3)
45+
46+
// -----
47+
48+
func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
49+
%0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
50+
return %0 : tensor<2x8x8x2xf32>
51+
}
52+
53+
module attributes {transform.with_named_sequence} {
54+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
55+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56+
// expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}}
57+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
58+
transform.yield
59+
}
60+
}
61+
62+
// -----
63+
64+
func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
65+
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
66+
return %0 : tensor<2x?x?x2xf32>
67+
}
68+
69+
module attributes {transform.with_named_sequence} {
70+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
71+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
72+
// expected-error @+1 {{apply Winograd Conv2D failed}}
73+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
74+
transform.yield
75+
}
76+
}

0 commit comments

Comments (0)