Skip to content

Commit 968bf01

Browse files
committed
[mlir][linalg] Implement Winograd Conv2D.
This patch implements the Winograd Conv2D algorithm. It supports several configurations of Winograd Conv2D, including F(2, 3), F(4, 3) and F(2, 5). These configurations show that the implementation can support different kernel sizes (3 and 5) and different output sizes (2 and 4). Besides the symmetric kernel sizes 3x3 and 5x5, this patch also supports 1x3, 3x1, 1x5, and 5x1 kernels. The implementation is based on the paper "Fast Algorithms for Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
1 parent bd5fbab commit 968bf01

File tree

12 files changed

+2103
-0
lines changed

12 files changed

+2103
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,96 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154
let hasVerifier = 1;
155155
}
156156

157+
// ODS definition of the `linalg.winograd_filter_transform` operation.
//
// Interface: declares the TilingInterface methods getIterationDomain,
// getLoopIteratorTypes, getResultTilePosition and getTiledImplementation; the
// C++ bodies of those methods are not part of this definition.
//
// Operands:   `filter` and `output`, both ranked tensors of any element type.
// Attributes: `output_height`, `output_width`, `m`, `r` (all I64Attr).
// Result:     a single ranked tensor, `result`.
//
// The custom assembly prints the attribute dictionary first, then the four
// integer attributes as keyword(value) groups, then the `ins`/`outs` operands
// with their types, and finally the result type.
//
// No `hasVerifier` is set here, so this definition only enforces the
// type/attribute kinds listed above.
// NOTE(review): `m` and `r` presumably follow the F(m, r) Winograd notation
// (output tile size, kernel size) -- confirm against the implementation.
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
158+
[DeclareOpInterfaceMethods<TilingInterface,
159+
["getIterationDomain",
160+
"getLoopIteratorTypes",
161+
"getResultTilePosition",
162+
"getTiledImplementation"]>]> {
163+
let summary = "Winograd filter transform operator";
164+
let description = [{
165+
linalg.winograd_filter_transform transforms the filter of conv2D.
166+
}];
167+
168+
let arguments = (ins AnyRankedTensor:$filter,
169+
AnyRankedTensor:$output,
170+
I64Attr:$output_height,
171+
I64Attr:$output_width,
172+
I64Attr:$m,
173+
I64Attr:$r
174+
);
175+
176+
let results = (outs AnyRankedTensor:$result);
177+
let assemblyFormat = [{
178+
attr-dict
179+
`output_height` `(` $output_height `)`
180+
`output_width` `(` $output_width `)`
181+
`m` `(` $m `)`
182+
`r` `(` $r `)`
183+
`ins` `(` $filter `:` type($filter) `)`
184+
`outs` `(` $output `:` type($output) `)`
185+
`->` type($result)
186+
}];
187+
}
188+
189+
// ODS definition of the `linalg.winograd_input_transform` operation.
//
// Interface: declares the TilingInterface methods getIterationDomain,
// getLoopIteratorTypes, getResultTilePosition and getTiledImplementation; the
// C++ bodies of those methods are not part of this definition.
//
// Operands:   `input` and `output`, both ranked tensors of any element type.
// Attributes: `output_height`, `output_width`, `m`, `r` (all I64Attr).
// Result:     a single ranked tensor, `result`.
//
// The custom assembly prints the attribute dictionary first, then the four
// integer attributes as keyword(value) groups, then the `ins`/`outs` operands
// with their types, and finally the result type.
//
// No `hasVerifier` is set here, so this definition only enforces the
// type/attribute kinds listed above.
// NOTE(review): `m` and `r` presumably follow the F(m, r) Winograd notation
// (output tile size, kernel size) -- confirm against the implementation.
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
190+
[DeclareOpInterfaceMethods<TilingInterface,
191+
["getIterationDomain",
192+
"getLoopIteratorTypes",
193+
"getResultTilePosition",
194+
"getTiledImplementation"]>]> {
195+
let summary = "Winograd input transform operator";
196+
let description = [{
197+
linalg.winograd_input_transform transforms the input of conv2D.
198+
}];
199+
200+
let arguments = (ins AnyRankedTensor:$input,
201+
AnyRankedTensor:$output,
202+
I64Attr:$output_height,
203+
I64Attr:$output_width,
204+
I64Attr:$m,
205+
I64Attr:$r
206+
);
207+
208+
let results = (outs AnyRankedTensor:$result);
209+
let assemblyFormat = [{
210+
attr-dict
211+
`output_height` `(` $output_height `)`
212+
`output_width` `(` $output_width `)`
213+
`m` `(` $m `)`
214+
`r` `(` $r `)`
215+
`ins` `(` $input `:` type($input) `)`
216+
`outs` `(` $output `:` type($output) `)`
217+
`->` type($result)
218+
}];
219+
}
220+
221+
// ODS definition of the `linalg.winograd_output_transform` operation.
//
// Interface: declares the TilingInterface methods getIterationDomain,
// getLoopIteratorTypes, getResultTilePosition and getTiledImplementation; the
// C++ bodies of those methods are not part of this definition.
//
// Operands:   `value` and `output`, both ranked tensors of any element type.
// Attributes: `m` and `r` (both I64Attr). Unlike the filter and input
//             transform ops, this op carries no `output_height` /
//             `output_width` attributes.
// Result:     a single ranked tensor, `result`.
//
// The custom assembly prints the attribute dictionary first, then `m` and `r`
// as keyword(value) groups, then the `ins`/`outs` operands with their types,
// and finally the result type.
//
// No `hasVerifier` is set here, so this definition only enforces the
// type/attribute kinds listed above.
// NOTE(review): `m` and `r` presumably follow the F(m, r) Winograd notation
// (output tile size, kernel size) -- confirm against the implementation.
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
222+
[DeclareOpInterfaceMethods<TilingInterface,
223+
["getIterationDomain",
224+
"getLoopIteratorTypes",
225+
"getResultTilePosition",
226+
"getTiledImplementation"]>]> {
227+
let summary = "Winograd output transform operator";
228+
let description = [{
229+
linalg.winograd_output_transform transforms the output of conv2D.
230+
}];
231+
232+
let arguments = (ins AnyRankedTensor:$value,
233+
AnyRankedTensor:$output,
234+
I64Attr:$m,
235+
I64Attr:$r
236+
);
237+
238+
let results = (outs AnyRankedTensor:$result);
239+
let assemblyFormat = [{
240+
attr-dict
241+
`m` `(` $m `)`
242+
`r` `(` $r `)`
243+
`ins` `(` $value `:` type($value) `)`
244+
`outs` `(` $output `:` type($output) `)`
245+
`->` type($result)
246+
}];
247+
}
248+
157249
#endif // LINALG_OPS

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2587,4 +2587,84 @@ def MapCopyToThreadsOp :
25872587
}];
25882588
}
25892589

2590+
//===----------------------------------------------------------------------===//
2591+
// Winograd Conv2D
2592+
//===----------------------------------------------------------------------===//
2593+
2594+
// ODS definition of the transform-dialect op `structured.winograd_conv2d`.
//
// Operands:   `target`, a transform handle.
// Attributes: `m` and `r` (both I64Attr). Neither is spelled out in the
//             assembly format, so they are carried by `attr-dict`.
// Result:     `transformed`, a transform handle.
//
// The op uses TransformEachOpTrait; the extra class declaration provides the
// per-payload hook `applyToOne`, whose `target` parameter is typed as
// `linalg::LinalgOp`.
//
// NOTE(review): the only declared builder takes just `target`, while `m` and
// `r` are non-optional attributes -- confirm that the builder's C++ body
// supplies them (or that the builder is intentionally limited).
// NOTE(review): `m` and `r` presumably follow the F(m, r) Winograd notation
// (output tile size, kernel size) -- confirm against the implementation.
def WinogradConv2DOp : Op<Transform_Dialect,
2595+
"structured.winograd_conv2d",
2596+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2597+
TransformOpInterface, TransformEachOpTrait,
2598+
ReportTrackingListenerFailuresOpTrait]> {
2599+
let description = [{
2600+
Use Winograd Conv2D algorithm to compute Conv2D. It will decompose conv2d
2601+
into three transform operators, i.e., filter transform operator, input
2602+
transform operator and output transform operator. In addition, use batched
2603+
matmul to compute the transformed filter and input matrices.
2604+
2605+
#### Return modes:
2606+
2607+
This operation fails if `target` is unsupported. Otherwise, the operation
2608+
succeeds and returns a handle of the sequence that replaces the original
2609+
convolution.
2610+
}];
2611+
2612+
let arguments = (ins TransformHandleTypeInterface:$target,
2613+
I64Attr:$m,
2614+
I64Attr:$r);
2615+
let results = (outs TransformHandleTypeInterface:$transformed);
2616+
2617+
let assemblyFormat =
2618+
"$target attr-dict `:` functional-type($target, results)";
2619+
2620+
let builders = [
2621+
OpBuilder<(ins "Value":$target)>
2622+
];
2623+
2624+
let extraClassDeclaration = [{
2625+
::mlir::DiagnosedSilenceableFailure applyToOne(
2626+
::mlir::transform::TransformRewriter &rewriter,
2627+
::mlir::linalg::LinalgOp target,
2628+
::mlir::transform::ApplyToEachResultList &results,
2629+
::mlir::transform::TransformState &state);
2630+
}];
2631+
}
2632+
2633+
// ODS definition of the transform-dialect op
// `structured.winograd_conv2d_rewrite`.
//
// Operands:   `target`, a transform handle.
// Attributes: none declared.
// Result:     `transformed`, a transform handle.
//
// The op uses TransformEachOpTrait; the extra class declaration provides the
// per-payload hook `applyToOne`, whose `target` parameter is a generic
// `Operation *` (not a specific op class), so any payload-kind filtering has
// to happen inside the C++ implementation, which is not part of this
// definition.
//
// Per the description below, callers are expected to tile the Winograd
// transform ops to supported sizes before applying this op.
def WinogradConv2DRewriteOp : Op<Transform_Dialect,
2634+
"structured.winograd_conv2d_rewrite",
2635+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2636+
TransformOpInterface, TransformEachOpTrait,
2637+
ReportTrackingListenerFailuresOpTrait]> {
2638+
let description = [{
2639+
Rewrite winograd conv2D operators. It will convert filter, input and
2640+
output transform operators into a combination of scf, tensor, and linalg
2641+
equivalent operators. Before applying this transform operator, users
2642+
need to tile winograd transform operators into supported sizes.
2643+
2644+
#### Return modes:
2645+
2646+
This operation fails if `target` is unsupported. Otherwise, the operation
2647+
succeeds and returns a handle of the sequence that replaces the original
2648+
operator.
2649+
}];
2650+
2651+
let arguments = (ins TransformHandleTypeInterface:$target);
2652+
let results = (outs TransformHandleTypeInterface:$transformed);
2653+
2654+
let assemblyFormat =
2655+
"$target attr-dict `:` functional-type($target, results)";
2656+
2657+
let builders = [
2658+
OpBuilder<(ins "Value":$target)>
2659+
];
2660+
2661+
let extraClassDeclaration = [{
2662+
::mlir::DiagnosedSilenceableFailure applyToOne(
2663+
::mlir::transform::TransformRewriter &rewriter,
2664+
::mlir::Operation *target,
2665+
::mlir::transform::ApplyToEachResultList &results,
2666+
::mlir::transform::TransformState &state);
2667+
}];
2668+
}
2669+
25902670
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,20 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
13121312
linalg::BatchMatmulOp op,
13131313
bool transposeLHS = true);
13141314

1315+
/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm.
/// `rewriter` is used to perform the rewrite of `op`; `m` and `r` select the
/// Winograd configuration.
/// NOTE(review): `m` and `r` presumably follow the F(m, r) notation (output
/// tile size, kernel size), and the result is presumably the operation that
/// replaces `op`, or failure when `op` is unsupported -- confirm against the
/// implementation.
1316+
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1317+
linalg::Conv2DNhwcFhwcOp op, int64_t m,
1318+
int64_t r);
1319+
/// Rewrite entry point for a linalg::WinogradFilterTransformOp.
/// NOTE(review): presumably returns the operation that replaces `op`, or
/// failure when `op` cannot be rewritten -- confirm against the
/// implementation.
FailureOr<Operation *>
1320+
winogradConv2DRewriteFilterTransform(RewriterBase &rewriter,
1321+
linalg::WinogradFilterTransformOp op);
1322+
/// Rewrite entry point for a linalg::WinogradInputTransformOp.
/// NOTE(review): presumably returns the operation that replaces `op`, or
/// failure when `op` cannot be rewritten -- confirm against the
/// implementation.
FailureOr<Operation *>
1323+
winogradConv2DRewriteInputTransform(RewriterBase &rewriter,
1324+
linalg::WinogradInputTransformOp op);
1325+
/// Rewrite entry point for a linalg::WinogradOutputTransformOp.
/// NOTE(review): presumably returns the operation that replaces `op`, or
/// failure when `op` cannot be rewritten -- confirm against the
/// implementation.
FailureOr<Operation *>
1326+
winogradConv2DRewriteOutputTransform(RewriterBase &rewriter,
1327+
linalg::WinogradOutputTransformOp op);
1328+
13151329
//===----------------------------------------------------------------------===//
13161330
// Rewrite patterns wrapping transformations.
13171331
// TODO: every single such pattern should be a close to noop wrapper around a
@@ -1692,6 +1706,11 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
16921706
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
16931707
const ControlBlockPackMatmulFn &controlFn);
16941708

1709+
/// Patterns to apply Winograd Conv2D algorithm.
/// Both functions add their patterns to the caller-provided `patterns` set.
/// `populateWinogradConv2DPatterns` is parameterized by `m` and `r`;
/// `populateWinogradConv2DRewritePatterns` takes no configuration.
/// NOTE(review): `m` and `r` presumably follow the F(m, r) notation (output
/// tile size, kernel size), and the rewrite patterns presumably lower the
/// Winograd filter/input/output transform ops -- confirm against the
/// implementation.
1710+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
1711+
int64_t r);
1712+
void populateWinogradConv2DRewritePatterns(RewritePatternSet &patterns);
1713+
16951714
} // namespace linalg
16961715
} // namespace mlir
16971716

0 commit comments

Comments
 (0)