[mlir][linalg] Implement Winograd Conv2D. #94470


Closed
Hsiangkai wants to merge 4 commits from the winograd-conv2d branch

Conversation

Hsiangkai (Contributor)

This patch implements the Winograd Conv2D algorithm. It supports several configurations of Winograd Conv2D, including F(2, 3), F(4, 3), and F(2, 5). These configurations show that the implementation can support different kernel sizes (3 and 5) and different output sizes (2 and 4). Besides the symmetric kernel sizes 3x3 and 5x5, this patch also supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper Fast Algorithms for Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).
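For readers new to the algorithm, here is a minimal, self-contained C++ sketch of the 1D case F(2, 3). It is not part of the patch, and it uses the standard matrices from the Lavin and Gray paper, whereas the patch's constant tables are scaled variants paired with a scalarFactor. It computes a 2-element output of a 3-tap filter with 4 multiplies instead of 6:

#include <array>
#include <cstdio>

// F(2, 3): m = 2 outputs, r = 3 filter taps, alpha = m + r - 1 = 4 inputs.
int main() {
  std::array<float, 3> g = {1.0f, 2.0f, 3.0f};       // filter
  std::array<float, 4> d = {4.0f, 5.0f, 6.0f, 7.0f}; // input tile

  // U = G x g (filter transform); G is 4x3.
  std::array<float, 4> u = {g[0], (g[0] + g[1] + g[2]) / 2,
                            (g[0] - g[1] + g[2]) / 2, g[2]};
  // V = B^T x d (input transform); B^T is 4x4.
  std::array<float, 4> v = {d[0] - d[2], d[1] + d[2], d[2] - d[1], d[1] - d[3]};
  // M = U @ V (element-wise product): the 4 multiplies.
  std::array<float, 4> m = {u[0] * v[0], u[1] * v[1], u[2] * v[2], u[3] * v[3]};
  // Y = A^T x M (output transform); A^T is 2x4.
  float y0 = m[0] + m[1] + m[2];
  float y1 = m[1] - m[2] - m[3];

  std::printf("winograd: %g %g\n", y0, y1);
  // Direct correlation for comparison: y[i] = sum_k d[i + k] * g[k].
  std::printf("direct:   %g %g\n", d[0] * g[0] + d[1] * g[1] + d[2] * g[2],
              d[1] * g[0] + d[2] * g[1] + d[3] * g[2]);
  return 0;
}

The 2D cases F(m x m, r x r) implemented in this patch nest this construction, so each transform becomes a pair of matrix multiplies (e.g. G x g x G^T), which is exactly what the linalg.matmul pairs in WinogradConv2D.cpp compute.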

@llvmbot (Member) commented Jun 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Hsiangkai Wang (Hsiangkai)

Changes

This patch implements the Winograd Conv2D algorithm. It supports several configurations of Winograd Conv2D, including F(2, 3), F(4, 3), and F(2, 5). These configurations show that the implementation can support different kernel sizes (3 and 5) and different output sizes (2 and 4). Besides the symmetric kernel sizes 3x3 and 5x5, this patch also supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper Fast Algorithms for Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).


Patch is 88.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94470.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+830)
  • (added) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+570)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+12)
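The TestLinalgTransforms.cpp change falls outside the truncated diff, but presumably it just wires the new entry point into the existing test pass. A minimal sketch of how any pass could drive these rewrites, assuming only the populateWinogradConv2DPatterns declaration added by this patch:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: greedily apply the Winograd Conv2D rewrite patterns
// to all matching convolutions nested under `op`.
static mlir::LogicalResult applyWinogradConv2D(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::linalg::populateWinogradConv2DPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}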
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..a2f543400be85 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1692,6 +1692,9 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm.
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 7e3dc56e0acdc..a7dcc29b5b9be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
new file mode 100644
index 0000000000000..07a1b55ff8813
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -0,0 +1,830 @@
+//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// Implement Winograd Conv2D algorithm. The implementation is based on the
+/// paper: Fast Algorithms for Convolutional Neural Networks
+/// (https://arxiv.org/abs/1509.09308)
+///
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace linalg {
+
+namespace {
+
+// clang-format off
+// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its
+// result. The formula of the minimal 2D filtering algorithm F(m x m, r x r),
+// where m is the output dimension and r is the filter dimension, is
+//
+// Y = A^T x [ (G x g x G^T) @ (B^T x d x B) ] x A
+//
+// where g is the filter, d is the input data, and @ denotes element-wise
+// multiplication. We need to prepare 6 constant transformation matrices,
+// G, G^T, B^T, B, A^T, and A for this formula.
+//
+// The following tables define these constant transformation matrices for
+// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5)
+constexpr float G_2x2_3x3[] = {
+   -1,     0,   0,
+ 1./2, -1./2, 1./2,
+ 1./2,  1./2, 1./2,
+    0,     0,    1
+};
+
+constexpr float GT_2x2_3x3[] = {
+   -1,  1./2, 1./2, 0,
+    0, -1./2, 1./2, 0,
+    0,  1./2, 1./2, 1
+};
+
+constexpr float BT_2x2_3x3[] = {
+   -1,    0,   1,   0,
+    0,   -1,   1,   0,
+    0,    1,   1,   0,
+    0,   -1,   0,   1
+};
+
+constexpr float B_2x2_3x3[] = {
+   -1,    0,   0,   0,
+    0,   -1,   1,  -1,
+    1,    1,   1,   0,
+    0,    0,   0,   1
+};
+
+constexpr float AT_2x2_3x3[] = {
+    1,    1,   1,   0,
+    0,   -1,   1,   1
+};
+
+constexpr float A_2x2_3x3[] = {
+    1,    0,
+    1,   -1,
+    1,    1,
+    0,    1
+};
+
+constexpr float G_4x4_3x3[] = {
+     1,     0,     0,
+ -1./3,  1./3, -1./3,
+ -1./3, -1./3, -1./3,
+ 1./12, -1./6,  1./3,
+ 1./12,  1./6,  1./3,
+     0,     0,     1
+};
+
+constexpr float GT_4x4_3x3[] = {
+ 1,  -1./3, -1./3, 1./12, 1./12, 0,
+ 0,   1./3, -1./3, -1./6,  1./6, 0,
+ 0,  -1./3, -1./3,  1./3,  1./3, 1
+};
+
+constexpr float BT_4x4_3x3[] = {
+ 1./4,     0, -5./16,      0, 1./16,     0,
+    0,  1./4,  -1./4, -1./16, 1./16,     0,
+    0, -1./4,  -1./4,  1./16, 1./16,     0,
+    0,  1./4,  -1./8,  -1./4,  1./8,     0,
+    0, -1./4,  -1./8,   1./4,  1./8,     0,
+    0,  1./4,      0, -5./16,     0, 1./16
+};
+
+constexpr float B_4x4_3x3[] = {
+   1./4,      0,     0,     0,     0,      0,
+      0,   1./4, -1./4,  1./4, -1./4,   1./4,
+ -5./16,  -1./4, -1./4, -1./8, -1./8,      0,
+      0, -1./16, 1./16, -1./4,  1./4, -5./16,
+  1./16,  1./16, 1./16,  1./8,  1./8,      0,
+      0,      0,     0,     0,     0,  1./16
+};
+
+constexpr float AT_4x4_3x3[] = {
+ 1./8,  1./4, 1./4,  1./8, 1./8,    0,
+    0, -1./4, 1./4, -1./4, 1./4,    0,
+    0,  1./4, 1./4,  1./2, 1./2,    0,
+    0, -1./4, 1./4,    -1,    1, 1./2
+};
+
+constexpr float A_4x4_3x3[] = {
+  1./8,     0,    0,     0,
+  1./4, -1./4, 1./4, -1./4,
+  1./4,  1./4, 1./4,  1./4,
+  1./8, -1./4, 1./2,    -1,
+  1./8,  1./4, 1./2,     1,
+     0,     0,    0,  1./2
+};
+
+constexpr float G_2x2_5x5[] = {
+     1,     0,      0,      0,      0,
+  1./6, -1./6,   1./6,  -1./6,   1./6,
+ -1./6, -1./6,  -1./6,  -1./6,  -1./6,
+-4./15, 2./15, -1./15,  1./30, -1./60,
+ 1./60, 1./30,  1./15,  2./15,  4./15,
+     0,     0,      0,      0,      1
+};
+
+constexpr float GT_2x2_5x5[] = {
+   1,  1./6, -1./6, -4./15, 1./60, 0,
+   0, -1./6, -1./6,  2./15, 1./30, 0,
+   0,  1./6, -1./6, -1./15, 1./15, 0,
+   0, -1./6, -1./6,  1./30, 2./15, 0,
+   0,  1./6, -1./6, -1./60, 4./15, 1
+};
+
+constexpr float BT_2x2_5x5[] = {
+ 1./8,  3./16,  -1./4,  -3./16,   1./8,    0,
+    0,   1./8,  1./16,  -5./16,   1./8,    0,
+    0,  -1./8, -5./16,  -1./16,   1./8,    0,
+    0,   1./4,  -1./8,   -1./4,   1./8,    0,
+    0,  -1./8,  -1./4,    1./8,   1./4,    0,
+    0,   1./8,  3./16,   -1./4, -3./16, 1./8
+};
+
+constexpr float B_2x2_5x5[] = {
+   1./8,      0,      0,     0,     0,      0,
+  3./16,   1./8,  -1./8,  1./4, -1./8,   1./8,
+  -1./4,  1./16, -5./16, -1./8, -1./4,  3./16,
+ -3./16, -5./16, -1./16, -1./4,  1./8,  -1./4,
+   1./8,   1./8,   1./8,  1./8,  1./4, -3./16,
+      0,      0,      0,     0,     0,   1./8
+};
+
+constexpr float AT_2x2_5x5[] = {
+  1./2,  1, 1,  2, 1,    0,
+     0, -1, 1, -1, 2, 1./2
+};
+
+constexpr float A_2x2_5x5[] = {
+ 1./2,    0,
+    1,   -1,
+    1,    1,
+    2,   -1,
+    1,    2,
+    0, 1./2
+};
+// clang-format on
+
+using TransformMapKeyTy = std::pair<int, int>;
+
+// We use F(m, r) to define the size of minimal filtering algorithms.
+// m is the output dimension and r is the filter dimension. We can get
+// the input dimension, alpha, from the formula alpha = m + r - 1.
+//
+// For example, when m = 2 and r = 3, we know its input size is 4.
+// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+// 2x2 output result.
+constexpr TransformMapKeyTy F_2_3{2, 3};
+constexpr TransformMapKeyTy F_4_3{4, 3};
+constexpr TransformMapKeyTy F_2_5{2, 5};
+
+struct TransformMatrix {
+  TransformMatrix(const float *table, int64_t rows, int64_t cols,
+                  int64_t scalarFactor = 1)
+      : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
+
+  const float *table;
+  int64_t rows;
+  int64_t cols;
+  int64_t scalarFactor;
+};
+
+// Map from (m, r) to G transform matrix.
+const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> GMatrices = {
+    {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+    {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+    {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+};
+
+// Map from (m, r) to GT transform matrix.
+const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> GTMatrices = {
+    {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+    {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+    {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+};
+
+// Map from (m, r) to BT transform matrix.
+const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> BTMatrices = {
+    {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+    {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+    {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+};
+
+// Map from (m, r) to B transform matrix.
+const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> BMatrices = {
+    {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+    {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+    {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+};
+
+// Map from (m, r) to AT transform matrix.
+const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> ATMatrices = {
+    {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)},
+    {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)},
+    {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)},
+};
+
+// Map from (m, r) to A transform matrix.
+const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> AMatrices = {
+    {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)},
+    {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)},
+    {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)},
+};
+
+Value create2DTransformMatrix(RewriterBase &rewriter, Location loc,
+                              TransformMatrix transform, Type type) {
+  ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
+
+  return rewriter.create<arith::ConstantOp>(
+      loc, DenseFPElementsAttr::get(
+               RankedTensorType::get(
+                   SmallVector<int64_t>{transform.rows, transform.cols}, type),
+               constVec));
+}
+
+// This function transforms the filter. The data layout of the filter is FHWC.
+// The transformation matrices are 2-dimensional, so we extract an H x W tile
+// from the FHWC filter first, generating 2 levels of loops to iterate on F
+// and C. After the transformation, we get
+//
+// scf.for %f = lo_f to hi_f step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract filter<h x w> from filter<f x h x w x c>
+//     %ret = linalg.matmul G, %extracted
+//     %ret = linalg.matmul %ret, GT
+//     %inserted = insert %ret into filter<h x w x c x f>
+//
+Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
+                      int64_t outputH, int64_t outputW,
+                      bool leftTransform = true, bool rightTransform = true) {
+  auto filterType = cast<ShapedType>(filter.getType());
+  Type elementType = filterType.getElementType();
+  auto filterShape = filterType.getShape(); // F, H, W, C
+  int64_t filterF = filterShape[0];
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t filterC = filterShape[3];
+  int64_t alphaH = outputH + filterH - 1;
+  int64_t alphaW = outputW + filterW - 1;
+
+  // Return shape is <H x W x C x F>
+  auto retType =
+      RankedTensorType::get({alphaH, alphaW, filterC, filterF}, elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value FIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (1, H, W, 1) from (F, H, W, C)
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets = {FIter, zeroIndex, zeroIndex, CIter};
+  SmallVector<OpFoldResult, 4> sizes = {oneIndex,                       // F
+                                        rewriter.getIndexAttr(filterH), // H
+                                        rewriter.getIndexAttr(filterW), // W
+                                        oneIndex};                      // C
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto targetType =
+      RankedTensorType::get({1, filterH, filterW, 1}, elementType);
+  auto extractFilterOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, targetType, filter, offsets, sizes, strides);
+
+  // Extract (H, W) from (1, H, W, 1)
+  // g = extracted (H, W)
+  auto extractFilterType =
+      RankedTensorType::get({filterH, filterW}, elementType);
+  auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, extractFilterOp, extractFilterType);
+
+  TransformMapKeyTy key = {leftTransform ? outputH : outputW,
+                           leftTransform ? filterH : filterW};
+  int64_t retRows = 1;
+  Value matmulRetValue = extractFilter;
+  if (leftTransform) {
+    // Get constant transform matrix G
+    auto it = GMatrices.find(key);
+    assert(it != GMatrices.end());
+    const TransformMatrix &GMatrix = it->second;
+
+    retRows = GMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType);
+    // Multiply G x g
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix GT
+    auto it = GTMatrices.find(key);
+    assert(it != GTMatrices.end());
+    const TransformMatrix &GTMatrix = it->second;
+
+    auto matmulType =
+        RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value GT = create2DTransformMatrix(rewriter, loc, GTMatrix, elementType);
+    // Multiply u = (G x g) x GT
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  // Insert u
+  // Insert (H, W) to (H, W, 1, 1)
+  auto sliceType = RankedTensorType::get({alphaH, alphaW, 1, 1}, elementType);
+  auto init =
+      rewriter.create<tensor::EmptyOp>(loc, sliceType.getShape(), elementType);
+  auto result = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, matmulRetValue, init);
+
+  // Insert (H, W, 1, 1) to (H, W, C, F)
+  SmallVector<OpFoldResult, 4> retOffsets = {zeroIndex, zeroIndex, CIter,
+                                             FIter};
+  SmallVector<OpFoldResult, 4> retSizes = {rewriter.getIndexAttr(alphaH),
+                                           rewriter.getIndexAttr(alphaW),
+                                           oneIndex, oneIndex};
+
+  Value iterArg = innerForOp.getRegionIterArgs()[0];
+  auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
+      loc, result, iterArg, retOffsets, retSizes, strides);
+
+  rewriter.create<scf::YieldOp>(loc, insertSliceOp.getResult());
+
+  rewriter.setInsertionPointToEnd(outerForBody);
+  rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0));
+
+  rewriter.setInsertionPointAfter(outerForOp);
+
+  return outerForOp.getResult(0);
+}
+
+// This function transforms the input. The data layout of the input is NHWC.
+// The transformation matrices are 2-dimensional, so we extract an H x W tile
+// from the NHWC input first, generating 2 levels of loops to iterate on N
+// and C. After the transformation, we get
+//
+// scf.for %n = lo_n to hi_n step 1
+//   scf.for %c = lo_c to hi_c step 1
+//     %extracted = extract input<h x w> from input<n x h x w x c>
+//     %ret = linalg.matmul BT, %extracted
+//     %ret = linalg.matmul %ret, B
+//     %inserted = insert %ret into input<h x w x n x c>
+//
+Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
+                     int64_t outputH, int64_t outputW,
+                     bool leftTransform = true, bool rightTransform = true) {
+  auto inputType = cast<ShapedType>(input.getType());
+  Type elementType = inputType.getElementType();
+  auto inputShape = inputType.getShape(); // N, H, W, C
+  int64_t inputN = inputShape[0];
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t inputC = inputShape[3];
+
+  auto retType =
+      RankedTensorType::get({inputH, inputW, inputN, inputC}, elementType);
+  Value retValue =
+      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
+
+  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN);
+  auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC);
+  auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  auto outerForOp =
+      rewriter.create<scf::ForOp>(loc, zeroIdx, nUpperBound, oneStep, retValue);
+  Block *outerForBody = outerForOp.getBody();
+  rewriter.setInsertionPointToStart(outerForBody);
+  Value NIter = outerForBody->getArgument(0);
+
+  auto innerForOp = rewriter.create<scf::ForOp>(
+      loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]);
+  Block *innerForBody = innerForOp.getBody();
+  rewriter.setInsertionPointToStart(innerForBody);
+  Value CIter = innerForBody->getArgument(0);
+
+  // Extract (1, H, W, 1) from (N, H, W, C)
+  auto zeroIndex = rewriter.getIndexAttr(0);
+  auto oneIndex = rewriter.getIndexAttr(1);
+  SmallVector<OpFoldResult, 4> offsets = {NIter, zeroIndex, zeroIndex, CIter};
+  SmallVector<OpFoldResult, 4> sizes = {oneIndex,                      // N
+                                        rewriter.getIndexAttr(inputH), // H
+                                        rewriter.getIndexAttr(inputW), // W
+                                        oneIndex};                     // C
+  SmallVector<OpFoldResult, 4> strides(4, oneIndex);
+
+  auto targetType = RankedTensorType::get({1, inputH, inputW, 1}, elementType);
+  auto extractInputOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, targetType, input, offsets, sizes, strides);
+
+  // Extract (H, W) from (1, H, W, 1)
+  // d = extracted (H, W)
+  auto extractInputType = RankedTensorType::get({inputH, inputW}, elementType);
+  auto extractInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, extractInputOp, extractInputType);
+
+  TransformMapKeyTy key = {leftTransform ? outputH : outputW,
+                           leftTransform ? inputH - outputH + 1
+                                         : inputW - outputW + 1};
+  int64_t retRows = 1;
+  int64_t retCols = 1;
+  Value matmulRetValue = extractInput;
+  if (leftTransform) {
+    // Get constant transform matrix BT
+    auto it = BTMatrices.find(key);
+    assert(it != BTMatrices.end());
+    const TransformMatrix &BTMatrix = it->second;
+
+    retRows = BTMatrix.rows;
+    auto matmulType = RankedTensorType::get({retRows, inputW}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+
+    Value BT = create2DTransformMatrix(rewriter, loc, BTMatrix, elementType);
+    // Multiply BT x d
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init});
+    matmulRetValue = matmulOp.getResult(0);
+  }
+
+  if (rightTransform) {
+    // Get constant transform matrix B
+    auto it = BMatrices.find(key);
+    assert(it != BMatrices.end());
+    const TransformMatrix &BMatrix = it->second;
+
+    retCols = BMatrix.cols;
+    auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
+    auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
+                                                 elementType);
+    Value B = create2DTransformMatrix(rewriter, loc, BMatrix, elementType);
+    // Multiply v = (BT x d) x B
+    auto matmulOp = rewriter.create<linalg::MatmulOp>(
+        loc, matmulType, ValueRange{matmulRetVa...
[truncated]

@MaheshRavishankar (Contributor) left a comment

I'd like to get more input on this, but just for context, in IREE we handled this slightly differently. We implemented an operation for each of the input/filter/output transforms for Winograd: https://github.com/iree-org/iree/blob/9d60462ebe1bd5dcd0cc8cb8ca7a5e45523c1bd4/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td#L983. These implement the TilingInterface, which allows executing these operations better (e.g., the ability to distribute them to threads). They eventually get decomposed into what is being implemented here. This has worked well in IREE. I would suggest we try to "upstream" what is being done in IREE.

cc @Max191 @bjacob and @harsh-nod
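(For context: implementing the TilingInterface means providing hooks like the following rough C++ sketch. This is not IREE's code; the hook names follow upstream MLIR's TilingInterface as of this PR, the signatures are abbreviated and illustrative, and MyWinogradInputTransformOp is a hypothetical op.)

// Sketch: the core TilingInterface hooks a winograd transform op would
// implement, attached here via an external model.
struct WinogradInputTransformOpTiling
    : public TilingInterface::ExternalModel<WinogradInputTransformOpTiling,
                                            MyWinogradInputTransformOp> {
  // One iterator type (typically parallel) per tileable dimension.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const;

  // The full iteration domain as (offset, size, stride) ranges.
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const;

  // Materialize the computation for one tile, given per-loop offsets/sizes.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const;
};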

@Hsiangkai (Contributor, Author)

> I'd like to get more input on this, but just for context, in IREE we handled this slightly differently. We implemented an operation for each of the input/filter/output transforms for Winograd: https://github.com/iree-org/iree/blob/9d60462ebe1bd5dcd0cc8cb8ca7a5e45523c1bd4/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td#L983. These implement the TilingInterface, which allows executing these operations better (e.g., the ability to distribute them to threads). They eventually get decomposed into what is being implemented here. This has worked well in IREE. I would suggest we try to "upstream" what is being done in IREE.

cc @Max191 @bjacob and @harsh-nod

Thanks for the information. I haven't dug into the implementation of IREE. I assume there is some pass to tile the input into the sizes supported by Winograd Conv2D. (I might be wrong; I will figure it out.) I will take a look at TilingInterface, too. If you all agree to "upstream" the IREE implementation to MLIR, I am fine with withdrawing this patch.

@MaheshRavishankar (Contributor)

> > I'd like to get more input on this, but just for context, in IREE we handled this slightly differently. We implemented an operation for each of the input/filter/output transforms for Winograd: https://github.com/iree-org/iree/blob/9d60462ebe1bd5dcd0cc8cb8ca7a5e45523c1bd4/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td#L983. These implement the TilingInterface, which allows executing these operations better (e.g., the ability to distribute them to threads). They eventually get decomposed into what is being implemented here. This has worked well in IREE. I would suggest we try to "upstream" what is being done in IREE.
> > cc @Max191 @bjacob and @harsh-nod
>
> Thanks for the information. I haven't dug into the implementation of IREE. I assume there is some pass to tile the input into the sizes supported by Winograd Conv2D. (I might be wrong; I will figure it out.) I will take a look at TilingInterface, too. If you all agree to "upstream" the IREE implementation to MLIR, I am fine with withdrawing this patch.

Yes, that's exactly what it does. We would like to upstream it, but I am not sure we have the bandwidth to do this (hence I haven't marked this as "changes requested"). We'd be happy to coordinate on upstreaming what is in IREE.

@Hsiangkai Hsiangkai requested a review from ftynse as a code owner June 6, 2024 21:46
@Hsiangkai (Contributor, Author)

I updated the patch to include a structured.winograd_conv2d transform op and added a test case to show how to use structured.tile_using_for and structured.winograd_conv2d to deal with conv2d ops of arbitrary input size.

@bjacob (Contributor) commented Jun 7, 2024

See also iree-org/iree#16571.

@Hsiangkai (Contributor, Author)

We are willing to coordinate with you to upstream the Winograd Conv2D algorithm. I have updated our implementation to use the TilingInterface. How about using this version as a starting point for the upstreaming work? If there is anything I can do to align with IREE's implementation or meet IREE's requirements, please tell me and I will help.

@Hsiangkai Hsiangkai force-pushed the winograd-conv2d branch 2 times, most recently from f0d0927 to 968bf01 Compare June 14, 2024 13:39
@MaheshRavishankar (Contributor) left a comment

Cool! @Max191, can you help review this and shepherd it?

@ftynse (Member) left a comment

Thank you for your contribution! This is a rather big change, so I would strongly suggest splitting it into several parts if you don't want to be blocked waiting on reviewers to find time to get through 2k LoC. For example:

  • introduce the data modification operations + tests;
  • implement the tiling interface for them + tests;
  • implement the transformation and one way of executing it (test pass or transform dialect);
  • implement the other way of executing it.

This will let reviewers focus on relevant aspects and not require committing a big block of time.

I left a couple of stylistic comments. The big one is to please document all functions and operations extensively.

If there is a design discussion to be had, please do so on https://discourse.llvm.org/c/mlir/31 rather than in PR comments. In particular, choosing the design for upstream MLIR is not conditioned on the needs of a particular downstream, however large.

Comment on lines +172 to +185
I64Attr:$m,
I64Attr:$r
Member:

Use more descriptive names.

Contributor Author:

m and r are variable names in the Winograd algorithm. I added more description for the operators.

@@ -1692,6 +1706,11 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);

/// Patterns to apply Winograd Conv2D algorithm.
Member:

What's the difference between 2DPatterns and 2DRewritePatterns?

Contributor Author:

I renamed 2DRewritePatterns to DecomposeWinogradOpsPatterns.

@Max191 (Contributor) left a comment

Thanks for working on adding these ops upstream, and with more options for kernel size/input tile size than what we have in IREE! Just some higher-level op semantics comments for now.

The main difference I see between this implementation and the IREE implementation is that the dimensionality of the transforms is different. In this implementation, the 4D input tensor is transformed into another 4D transformed tensor, while in the IREE implementation, a 6D tensor is produced. The extra dimensionality comes from having (alpha)x(alpha) dimensions expanded from the H and W dimensions. In other words, with an NHWC layout, the IREE input transform would produce tensor<TxTxNxceil(H/T)xceil(W/T)xC>.

This extra dimensionality is very useful for tiling, since the iteration space is actually over the innermost of those 6 dimensions. Trying to implement a tiled implementation where some of those dimensions are collapsed is tricky.

I would suggest reworking the op semantics a bit to allow for this extra dimensionality. Then, the tiled implementation from IREE should also just drop right in here. Also, another benefit of having separate input tile dimensions is that the input tile dimensions can be innermost if desired. This can potentially provide performance benefits from having contiguous accesses in the winograd ops. We have experimentally tried this in IREE but have not merged the change yet.

The other difference I notice is that there is no implicit padding on the input transform. This means that for shapes that are unaligned with the input tiles, the input transform will not capture the last partial tile of the input tensor. Instead, the input transform should extract the last partial tile and pad the slice with zeros to the input tile size (alpha).

The last comment is that the implementation here has explicit I64Attrs for the output height and width. This should not be necessary if the ops have padding semantics as suggested above, since there is then no need to know the output shape for the input or filter transform. Having this attribute means that the op is restricted to static image sizes only (making it an operand could allow some dynamic support, but it is better not to rely on having these sizes, since it is not necessary).
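(To make the padding point concrete: with an output tile size m, kernel size r, and alpha = m + r - 1, input tiles start every m rows/columns and overlap by r - 1. Below is a small C++ sketch, with a hypothetical helper name, of the slice an input transform with implicit padding would extract for one tile.)

#include <algorithm>
#include <cstdint>

// Rows to read for input tile `tileIdx` of an image with `inputH` rows.
// A trailing partial tile reads fewer than alpha rows and is zero-padded
// back up to alpha.
struct TileSlice {
  int64_t offset;  // first image row of the tile
  int64_t size;    // number of valid rows to extract
  int64_t padding; // zero rows appended to reach alpha
};

inline TileSlice inputTileRows(int64_t tileIdx, int64_t inputH, int64_t m,
                               int64_t r) {
  int64_t alpha = m + r - 1;
  int64_t offset = tileIdx * m;
  int64_t valid = std::min(alpha, inputH - offset);
  return {offset, valid, alpha - valid};
}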

@Hsiangkai (Contributor, Author)

> Thank you for your contribution! This is a rather big change, so I would strongly suggest splitting it into several parts if you don't want to be blocked waiting on reviewers to find time to get through 2k LoC. For example:
>
> • introduce the data modification operations + tests;
> • implement the tiling interface for them + tests;
> • implement the transformation and one way of executing it (test pass or transform dialect);
> • implement the other way of executing it.
>
> This will let reviewers focus on relevant aspects and not require committing a big block of time.
>
> I left a couple of stylistic comments. The big one is to please document all functions and operations extensively.
>
> If there is a design discussion to be had, please do so on https://discourse.llvm.org/c/mlir/31 rather than in PR comments. In particular, choosing the design for upstream MLIR is not conditioned on the needs of a particular downstream, however large.

Thanks for your review. I split the original patch into 4 smaller commits and addressed all your comments.

@Hsiangkai (Contributor, Author)

> Thanks for working on adding these ops upstream, and with more options for kernel size/input tile size than what we have in IREE! Just some higher-level op semantics comments for now.
>
> The main difference I see between this implementation and the IREE implementation is that the dimensionality of the transforms is different. In this implementation, the 4D input tensor is transformed into another 4D transformed tensor, while in the IREE implementation, a 6D tensor is produced. The extra dimensionality comes from having (alpha)x(alpha) dimensions expanded from the H and W dimensions. In other words, with an NHWC layout, the IREE input transform would produce tensor<TxTxNxceil(H/T)xceil(W/T)xC>.
>
> This extra dimensionality is very useful for tiling, since the iteration space is actually over the innermost of those 6 dimensions. Trying to implement a tiled implementation where some of those dimensions are collapsed is tricky.
>
> I would suggest reworking the op semantics a bit to allow for this extra dimensionality. Then, the tiled implementation from IREE should also just drop right in here. Also, another benefit of having separate input tile dimensions is that the input tile dimensions can be innermost if desired. This can potentially provide performance benefits from having contiguous accesses in the winograd ops. We have experimentally tried this in IREE but have not merged the change yet.
>
> The other difference I notice is that there is no implicit padding on the input transform. This means that for shapes that are unaligned with the input tiles, the input transform will not capture the last partial tile of the input tensor. Instead, the input transform should extract the last partial tile and pad the slice with zeros to the input tile size (alpha).
>
> The last comment is that the implementation here has explicit I64Attrs for the output height and width. This should not be necessary if the ops have padding semantics as suggested above, since there is then no need to know the output shape for the input or filter transform. Having this attribute means that the op is restricted to static image sizes only (making it an operand could allow some dynamic support, but it is better not to rely on having these sizes, since it is not necessary).

Thanks for your review and suggestions. It's very helpful to get insight into the IREE implementation, and I appreciate your detailed explanation. I will address your comments in the next update.

@Hsiangkai (Contributor, Author)

I updated the patch to produce 6D tensors from the filter/input transformations, instead of 4D.

Define high level winograd operators and convert conv_2d_nhwc_fhwc into
winograd operators. According to Winograd Conv2D algorithm, we need
three transform operators for input, filter, and output transformation.

The formula of the Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

where @ denotes element-wise multiplication:

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A
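In matrix form, with the dimensions of the F(2x2, 3x3) constant tables in WinogradConv2D.cpp (writing the element-wise product @ as \odot):

Y = A^T \left[ (G \, g \, G^T) \odot (B^T \, d \, B) \right] A, \qquad
g \in \mathbb{R}^{3 \times 3},\ G \in \mathbb{R}^{4 \times 3},\
d \in \mathbb{R}^{4 \times 4},\ B^T \in \mathbb{R}^{4 \times 4},\
A^T \in \mathbb{R}^{2 \times 4},\ Y \in \mathbb{R}^{2 \times 2}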

The implementation is based on the paper Fast Algorithms for
Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).

Add a transform operator structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg winograd operators.

Convert Linalg winograd_filter_transform, winograd_input_transform, and
winograd_output_transform into nested loops performing matrix
multiplications with constant transform matrices.

Support several configurations of Winograd Conv2D, including F(2, 3),
F(4, 3), and F(2, 5). These configurations show that the implementation
can support different kernel sizes (3 and 5) and different output sizes
(2 and 4). Besides the symmetric kernel sizes 3x3 and 5x5, this patch
also supports 1x3, 3x1, 1x5, and 5x1 kernels.

The implementation is based on the paper Fast Algorithms for
Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).

In order to support arbitrary-size input data for conv2d, implement the
TilingInterface for the winograd operators. Before converting the
winograd operators into nested loops with matrix multiplies, tile the
input of conv2d into the supported sizes first.


Add a transform operator structured.decompose_winograd_op to decompose
the winograd operators. Before applying the transform op, use
tile_using_for to tile the input data into supported sizes. The test
case shows how to tile and decompose the winograd operators.
@Hsiangkai (Contributor, Author)

I think I have addressed all the comments. I split the original patch into 4 commits, though I am not sure whether I did it the right way. Should I create 4 separate PRs instead of 1 PR with 4 commits? Please review it again. Thank you.

@Hsiangkai Hsiangkai requested review from Max191 and ftynse June 20, 2024 10:00
@Hsiangkai (Contributor, Author)

I created four PRs for the Winograd Conv2D implementation.

#96176
#96177
#96178
#96179

Let's move the review there.

@Hsiangkai Hsiangkai closed this Jun 20, 2024
@Hsiangkai (Contributor, Author)

Sorry, I didn't create the above PRs correctly. The correct PRs are

#96181
#96182
#96183
#96184
