Skip to content

[mlir][linalg] Decompose winograd operators #96183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Jul 18, 2024

Conversation

Hsiangkai
Copy link
Contributor

Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops with matrix multiplication
with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel size (3 and 5) and different output size
(2 and 4). Besides symetric kernel size 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)

Define high level winograd operators and convert conv_2d_nhwc_fhwc into
winograd operators. According to Winograd Conv2D algorithm, we need
three transform operators for input, filter, and output transformation.

The formula of Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Hsiangkai Wang (Hsiangkai)

Changes

Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops with matrix multiplication
with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel size (3 and 5) and different output size
(2 and 4). Besides symetric kernel size 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)


Patch is 99.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96183.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+114)
  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+51)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+14)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+78)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+25)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+1100)
  • (added) mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir (+88)
  • (added) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+105)
  • (added) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+248)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+24)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..de1097b6ac27b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$filter,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$value,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..68d0f713caad4 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle of the sequence that replaces the original
+    convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,
+                       I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..bb7ec590faad0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of output and r is the dimension
+/// size of filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
@@ -1692,6 +1699,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
+/// Patterns to decompose Winograd operators.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..7bf2a5bca037f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto filterElemType = filterType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (filterElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << filterElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned filterRank = filterType.getRank();
+  if (filterRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto inputElemType = inputType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (inputElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << inputElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned inputRank = inputType.getRank();
+  if (inputRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto valueElemType = valueType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (valueElemType != outputElemType) {
+    return emitOpError() << "expected element type of value " << valueElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned valueRank = valueType.getRank();
+  if (valueRank != 6)
+    return emitOpError() << "expected rank of input is 6";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 4)
+    return emitOpError() << "expected rank of output is 4";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc02788f9c441..d051b29e1f06f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return winogradConv2D(rewriter, op, getM(), getR());
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..d245723c85646
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,1100 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implement Winograd Conv2D algorithm. The implementation is based on the
+// paper: Fast Algorithms for Convolutional Neural Networks
+// (https://arxiv.org/abs/1509.09308)
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. The formula of minimal 2D filtering algorithm F(m x m, r x r),
+// m is the output dimension and r is the filter dimension, is
+//
+// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A
+//
+// g is filter and d is input data. We need to prepare 6 constant
+// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula.
+//
+// The following tables define these constant transformation matrices for
+// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
+constexpr float G_2x2_3x3[] = {
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = {
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = {
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = {
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = {
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+constexpr float G_4x4_3x3[] = {
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+constexpr float G_2x2_5x5[] = {
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = {
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+// We use F(m, r) to define the size of minimal filtering algorithms.
+// m is the output dimension and r is the filter dimension. We can get
+// the input dimension, alpha, from the formula, alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+struct TransformMatrix {
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+  const float *table;
+  int64_t rows;
+  int64_t cols;
+  int64_t scalarFactor;
+};
+
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+                              TransformMatrix transform, Type...
[truncated]

Add a transform operator structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operators.
@Hsiangkai Hsiangkai force-pushed the users/hsiangkai/winograd-ops-transform branch from ab54cf8 to 374b0d5 Compare June 20, 2024 12:49
Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops with matrix multiplication
with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel size (3 and 5) and different output size
(2 and 4). Besides symetric kernel size 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
Comment on lines +41 to +52
constexpr float G_2x2_3x3[] = {
-1, 0, 0,
1./2, -1./2, 1./2,
1./2, 1./2, 1./2,
0, 0, 1
};

constexpr float GT_2x2_3x3[] = {
-1, 1./2, 1./2, 0,
0, -1./2, 1./2, 0,
0, 1./2, 1./2, 1
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you considered introducing a (potentially constexpr) transpose function or some sort of transposed access iterator instead of hardcoding transposed matrices?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate it a bit more? I am not sure what the idea is here. Thank you.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the matrices for G and GT (transposed) are both hardcoded, and it's not ideal from code review and maintenance point of view. I'd much rather have the compiler do it for me somehow.

This is not a hard requirement as I don't expect these matrices to ever change.

Hsiangkai added 15 commits June 26, 2024 09:43
Define high level winograd operators and convert conv_2d_nhwc_fhwc into
winograd operators. According to Winograd Conv2D algorithm, we need
three transform operators for input, filter, and output transformation.

The formula of Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
Add a transform operator structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operators.
Define high level winograd operators and convert conv_2d_nhwc_fhwc into
winograd operators. According to Winograd Conv2D algorithm, we need
three transform operators for input, filter, and output transformation.

The formula of Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
Add a transform operator structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operators.
Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops with matrix multiplication
with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel size (3 and 5) and different output size
(2 and 4). Besides symetric kernel size 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
Comment on lines +41 to +52
constexpr float G_2x2_3x3[] = {
-1, 0, 0,
1./2, -1./2, 1./2,
1./2, 1./2, 1./2,
0, 0, 1
};

constexpr float GT_2x2_3x3[] = {
-1, 1./2, 1./2, 0,
0, -1./2, 1./2, 0,
0, 1./2, 1./2, 1
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the matrices for G and GT (transposed) are both hardcoded, and it's not ideal from code review and maintenance point of view. I'd much rather have the compiler do it for me somehow.

This is not a hard requirement as I don't expect these matrices to ever change.

Base automatically changed from users/hsiangkai/winograd-ops-transform to main July 11, 2024 13:45
Copy link

github-actions bot commented Jul 11, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for continuing to work on this!

Overall, the decomposition seems okay for the way it is being implemented, but I have some comments about how this will connect to the tiling implementation. The general idea is that the decomposition can either require all dims other than the input tile to be 1, or require none of them to be 1. See my comment on the input transform for more details, but this applies to the output transform as well.

@Hsiangkai Hsiangkai requested a review from Max191 July 17, 2024 14:15
Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! Thanks for all the great work so far!

@Hsiangkai
Copy link
Contributor Author

Looks good! Thanks for all the great work so far!

Thank you all for your review and recommendation. It's very helpful and I learnt a lot from it.

@Hsiangkai Hsiangkai merged commit 27ee33d into main Jul 18, 2024
7 checks passed
@Hsiangkai Hsiangkai deleted the users/hsiangkai/winograd-decompose branch July 18, 2024 05:04
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary:
Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops with matrix multiplication
with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3) and F(2, 5). These configurations show that the implementation
can support different kernel size (3 and 5) and different output size
(2 and 4). Besides symetric kernel size 3x3 and 5x5, this patch also
supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper, Fast Algorithm for
Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)

Test Plan: 

Reviewers: 

Reviewed By: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250998
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants