diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 64c538367267d..45726d6ee2224 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -154,4 +154,133 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax", let hasVerifier = 1; } +def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform", + [DeclareOpInterfaceMethods]> { + let summary = "Winograd filter transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of filter + transformation (G x g x G^T) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$filter, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $filter `:` type($filter) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform", + [DeclareOpInterfaceMethods]> { + let summary = "Winograd input transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of input + transformation (B^T x d x B) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$input, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $input `:` type($input) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + +def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform", + [DeclareOpInterfaceMethods]> { + let summary = "Winograd output transform operator"; + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). 
A^T, A, G^T, G, B^T, and B are + transformation matrices. + + This operator is defined to represent the high level concept of output + transformation (A^T x y x A) in the Winograd Conv2D algorithm. + }]; + + let arguments = (ins AnyRankedTensor:$value, + AnyRankedTensor:$output, + I64Attr:$m, + I64Attr:$r + ); + + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + attr-dict + `m` `(` $m `)` + `r` `(` $r `)` + `ins` `(` $value `:` type($value) `)` + `outs` `(` $output `:` type($output) `)` + `->` type($result) + }]; + let hasVerifier = 1; +} + #endif // LINALG_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 93e2c2db729da..71736eae38b4f 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2587,4 +2587,92 @@ def MapCopyToThreadsOp : }]; } +//===----------------------------------------------------------------------===// +// Winograd Conv2D +//===----------------------------------------------------------------------===// + +def WinogradConv2DOp : Op { + let description = [{ + Winograd Conv2D algorithm will convert linalg Conv2D operator into batched + matrix multiply. Before the matrix multiply, it will convert filter and + input into a format suitable for batched matrix multiply. After the matrix + multiply, it will convert output to the final result tensor. + + The algorithm F(m x m, r x r) is + + Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A + + The size of output Y is m x m. The size of filter g is r x r. The size of + input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are + transformation matrices. + + #### Return modes: + + This operation fails if `target` is unsupported. Otherwise, the operation + succeeds and returns a handle of the sequence that replaces the original + convolution. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + I64Attr:$m, + I64Attr:$r); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::linalg::LinalgOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DecomposeWinogradOp : Op { + let description = [{ + Decompose winograd operators. It will convert filter, input and output + transform operators into a combination of scf, tensor, and linalg + equivalent operators. Before applying this transform operator, users + need to tile winograd transform operators into supported sizes. + + #### Return modes: + + This operation fails if `target` is unsupported. Otherwise, the operation + succeeds and returns a handle of the sequence that replaces the original + operator. 
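+
+    For example, a sketch of a transform script that tiles a
+    `linalg.winograd_filter_transform` op and then decomposes the tiled op.
+    Handles and tile sizes are illustrative, and the mnemonic
+    `transform.structured.decompose_winograd_op` is assumed here (the exact
+    mnemonic is the one given in this op's definition):
+
+    ```mlir
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]}
+        in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loops:2 = transform.structured.tile_using_for %0
+        tile_sizes [1, 1, 0, 0, 0, 0]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    // Mnemonic assumed for illustration; it is defined by this op.
+    %2 = transform.structured.decompose_winograd_op %1
+        : (!transform.any_op) -> !transform.any_op
+    ```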
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 05e97befdec1f..d0eec2be1f8fb 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1312,6 +1312,58 @@ FailureOr transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op, bool transposeLHS = true); +/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm +/// F(m x m, r x r). m is the dimension size of output and r is the dimension +/// size of filter. +FailureOr winogradConv2D(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp op, int64_t m, + int64_t r); + +/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is +/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W +/// from FHWC first. We need to generate 2 levels of loops to iterate on F and +/// C. After the rewriting, we get +/// +/// scf.for %f = lo_f to hi_f step 1 +/// scf.for %c = lo_c to hi_c step 1 +/// %extracted = extract filter from filter +/// %ret = linalg.matmul G, %extracted +/// %ret = linalg.matmul %ret, GT +/// %inserted = insert %ret into filter +FailureOr +decomposeWinogradFilterTransformOp(RewriterBase &rewriter, + linalg::WinogradFilterTransformOp op); + +/// Rewrite linalg.winograd_input_transform. The data layout of the input is +/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W +/// from NHWC first. We need to generate 2 levels of loops to iterate on N and +/// C. After the rewriting, we get +/// +/// scf.for %n = lo_n to hi_n step 1 +/// scf.for %c = lo_c to hi_c step 1 +/// %extracted = extract input from input +/// %ret = linalg.matmul BT, %extracted +/// %ret = linalg.matmul %ret, B +/// %inserted = insert %ret into input +FailureOr +decomposeWinogradInputTransformOp(RewriterBase &rewriter, + linalg::WinogradInputTransformOp op); + +/// Rewrite linalg.winograd_output_transform. The data layout of the output is +/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W +/// from HWNF first. We need to generate 2 levels of loops to iterate on N and +/// F. After the transformation, we get +/// +/// scf.for %n = lo_n to hi_n step 1 +/// scf.for %f = lo_f to hi_f step 1 +/// %extracted = extract input from result +/// %ret = linalg.matmul AT, %extracted +/// %ret = linalg.matmul %ret, A +/// %inserted = insert %ret into ret +FailureOr +decomposeWinogradOutputTransformOp(RewriterBase &rewriter, + linalg::WinogradOutputTransformOp op); + //===----------------------------------------------------------------------===// // Rewrite patterns wrapping transformations. 
// TODO: every single such pattern should be a close to noop wrapper around a @@ -1692,6 +1744,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns, void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn); +/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r). +void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, + int64_t r); + +/// Patterns to decompose Winograd operators. +void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns); + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 57d126603ebd7..a416e1f6e257f 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2734,6 +2734,365 @@ FailureOr> SoftmaxOp::decomposeOperation(OpBuilder &b) { return SmallVector{result}; } +//===----------------------------------------------------------------------===// +// WinogradFilterTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradFilterTransformOp::verify() { + auto filterType = cast(getFilter().getType()); + auto outputType = cast(getOutput().getType()); + auto filterElemType = filterType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (filterElemType != outputElemType) { + return emitOpError() << "expected element type of input " << filterElemType + << " to match element type of output " + << outputElemType; + } + + unsigned filterRank = filterType.getRank(); + if (filterRank != 4) + return emitOpError() << "expected rank of input is 4"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 6) + return emitOpError() << "expected rank of output is 6"; + + return success(); +} + +SmallVector +WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value output = getOutput(); + SmallVector loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { + loopBounds[dim].offset = zero; + loopBounds[dim].size = getDimValue(builder, loc, output, dim); + loopBounds[dim].stride = one; + } + return loopBounds; +} + +SmallVector +WinogradFilterTransformOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder, + Location loc) { + if (auto val = opFoldResult.dyn_cast()) { + return val; + } else if (auto attr = opFoldResult.dyn_cast()) { + auto intAttr = cast(attr); + return builder.create(loc, intAttr); + } + // This should never happen if OpFoldResult is correctly formed. 
+ return nullptr; +} + +LogicalResult WinogradFilterTransformOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + auto zeroAttr = builder.getI64IntegerAttr(0); + auto oneAttr = builder.getI64IntegerAttr(1); + + resultOffsets.push_back(offsets[0]); + resultOffsets.push_back(offsets[1]); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(sizes[2]); + resultSizes.push_back(sizes[3]); + resultSizes.push_back(sizes[4]); + resultSizes.push_back(sizes[5]); + + return success(); +} + +FailureOr WinogradFilterTransformOp::getTiledImplementation( + OpBuilder &builder, ArrayRef offsets, + ArrayRef sizes) { + auto oneAttr = builder.getI64IntegerAttr(1); + + Location loc = getLoc(); + SmallVector strides(6, oneAttr); + SmallVector tiledOperands; + tiledOperands.emplace_back(getFilter()); + + SmallVector sliceOffsets, sliceSizes; + if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets, + sliceSizes))) + return failure(); + + tiledOperands.emplace_back(builder.create( + loc, getOutput(), sliceOffsets, sliceSizes, strides)); + + SmallVector resultTypes; + resultTypes.push_back(tiledOperands[1].getType()); + Operation *tiledOp = + mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; +} + +//===----------------------------------------------------------------------===// +// WinogradInputTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradInputTransformOp::verify() { + auto inputType = cast(getInput().getType()); + auto outputType = cast(getOutput().getType()); + auto inputElemType = inputType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (inputElemType != outputElemType) { + return emitOpError() << "expected element type of input " << inputElemType + << " to match element type of output " + << outputElemType; + } + + unsigned inputRank = inputType.getRank(); + if (inputRank != 4) + return emitOpError() << "expected rank of input is 4"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 6) + return emitOpError() << "expected rank of output is 6"; + + return success(); +} + +SmallVector +WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value output = getOutput(); + SmallVector loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { + loopBounds[dim].offset = zero; + loopBounds[dim].size = getDimValue(builder, loc, output, dim); + loopBounds[dim].stride = one; + } + return loopBounds; +} + +SmallVector +WinogradInputTransformOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradInputTransformOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + auto zeroAttr = builder.getI64IntegerAttr(0); + auto oneAttr = builder.getI64IntegerAttr(1); + + resultOffsets.push_back(offsets[0]); + resultOffsets.push_back(offsets[1]); + resultOffsets.push_back(zeroAttr); + 
resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(zeroAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(oneAttr); + resultSizes.push_back(sizes[2]); + resultSizes.push_back(sizes[3]); + resultSizes.push_back(sizes[4]); + resultSizes.push_back(sizes[5]); + + return success(); +} + +FailureOr +WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, + ArrayRef offsets, + ArrayRef sizes) { + auto oneAttr = builder.getI64IntegerAttr(1); + auto zeroAttr = builder.getI64IntegerAttr(0); + Value input = getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + int64_t m = getM(); + int64_t r = getR(); + int64_t alpha = m + r - 1; + int64_t alphaH = inputH != 1 ? alpha : 1; + int64_t alphaW = inputW != 1 ? alpha : 1; + auto alphaHAttr = builder.getI64IntegerAttr(alphaH); + auto alphaWAttr = builder.getI64IntegerAttr(alphaW); + + Location loc = getLoc(); + SmallVector tiledOperands; + SmallVector sliceOffsets, sliceSizes; + + auto context = builder.getContext(); + auto affineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + Value mappedOffset1 = builder.create( + loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc)); + Value mappedOffset2 = builder.create( + loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc)); + + sliceOffsets.push_back(zeroAttr); + sliceOffsets.push_back(mappedOffset1); + sliceOffsets.push_back(mappedOffset2); + sliceOffsets.push_back(zeroAttr); + sliceSizes.push_back(sizes[4]); + sliceSizes.push_back(alphaHAttr); + sliceSizes.push_back(alphaWAttr); + sliceSizes.push_back(sizes[5]); + SmallVector inputStrides(4, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getInput(), sliceOffsets, sliceSizes, inputStrides)); + + sliceOffsets.clear(); + sliceSizes.clear(); + if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets, + sliceSizes))) + return failure(); + + SmallVector outputStrides(6, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getOutput(), sliceOffsets, sliceSizes, outputStrides)); + + SmallVector resultTypes; + resultTypes.push_back(tiledOperands[1].getType()); + Operation *tiledOp = + mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; +} + +//===----------------------------------------------------------------------===// +// WinogradOutputTransformOp +//===----------------------------------------------------------------------===// + +LogicalResult WinogradOutputTransformOp::verify() { + auto valueType = cast(getValue().getType()); + auto outputType = cast(getOutput().getType()); + auto valueElemType = valueType.getElementType(); + auto outputElemType = outputType.getElementType(); + if (valueElemType != outputElemType) { + return emitOpError() << "expected element type of value " << valueElemType + << " to match element type of output " + << outputElemType; + } + + unsigned valueRank = valueType.getRank(); + if (valueRank != 6) + return emitOpError() << "expected rank of input is 6"; + + unsigned outputRank = outputType.getRank(); + if (outputRank != 4) + return emitOpError() << "expected rank of output is 4"; + + return success(); +} + +SmallVector +WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = 
builder.create(loc, 1); + Value value = getValue(); + SmallVector loopBounds(6); + for (unsigned dim = 0; dim < 6; ++dim) { + loopBounds[dim].offset = zero; + loopBounds[dim].size = getDimValue(builder, loc, value, dim); + loopBounds[dim].stride = one; + } + return loopBounds; +} + +SmallVector +WinogradOutputTransformOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(6, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradOutputTransformOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + auto zeroAttr = builder.getI64IntegerAttr(0); + int64_t m = getM(); + IntegerAttr mAttr = getMAttr(); + Location loc = getLoc(); + auto context = builder.getContext(); + auto affineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + Value mappedOffset1 = builder.create( + loc, affineMap, getValueFromOpFoldResult(offsets[0], builder, loc)); + Value mappedOffset2 = builder.create( + loc, affineMap, getValueFromOpFoldResult(offsets[1], builder, loc)); + + resultOffsets.push_back(zeroAttr); + resultOffsets.push_back(mappedOffset1); + resultOffsets.push_back(mappedOffset2); + resultOffsets.push_back(zeroAttr); + resultSizes.push_back(sizes[4]); + resultSizes.push_back(mAttr); + resultSizes.push_back(mAttr); + resultSizes.push_back(sizes[5]); + return success(); +} + +FailureOr WinogradOutputTransformOp::getTiledImplementation( + OpBuilder &builder, ArrayRef offsets, + ArrayRef sizes) { + auto oneAttr = builder.getI64IntegerAttr(1); + auto zeroAttr = builder.getI64IntegerAttr(0); + Location loc = getLoc(); + SmallVector tiledOperands; + SmallVector sliceOffsets, sliceSizes; + + sliceOffsets.push_back(offsets[0]); + sliceOffsets.push_back(offsets[1]); + sliceOffsets.push_back(zeroAttr); + sliceOffsets.push_back(zeroAttr); + sliceOffsets.push_back(zeroAttr); + sliceOffsets.push_back(zeroAttr); + sliceSizes.push_back(oneAttr); + sliceSizes.push_back(oneAttr); + sliceSizes.push_back(sizes[2]); + sliceSizes.push_back(sizes[3]); + sliceSizes.push_back(sizes[4]); + sliceSizes.push_back(sizes[5]); + SmallVector sliceStrides(6, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getValue(), sliceOffsets, sliceSizes, sliceStrides)); + + sliceOffsets.clear(); + sliceSizes.clear(); + if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets, + sliceSizes))) + return failure(); + + SmallVector strides(4, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getOutput(), sliceOffsets, sliceSizes, strides)); + + SmallVector resultTypes; + resultTypes.push_back(tiledOperands[1].getType()); + Operation *tiledOp = + mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; +} + //===----------------------------------------------------------------------===// // LinalgDialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index bc02788f9c441..358c15f145407 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3480,6 +3480,58 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( return DiagnosedSilenceableFailure::success(); } 
+//===----------------------------------------------------------------------===// +// WinogradConv2DOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto maybeTransformed = + TypeSwitch>(target) + .Case([&](linalg::Conv2DNhwcFhwcOp op) { + return winogradConv2D(rewriter, op, getM(), getR()); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(op, "not supported"); + }); + + if (failed(maybeTransformed)) + return emitDefaultSilenceableFailure(target); + + results.push_back(*maybeTransformed); + return DiagnosedSilenceableFailure::success(); +} + +DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto maybeTransformed = + TypeSwitch>(target) + .Case([&](linalg::WinogradFilterTransformOp op) { + return decomposeWinogradFilterTransformOp(rewriter, op); + }) + .Case([&](linalg::WinogradInputTransformOp op) { + return decomposeWinogradInputTransformOp(rewriter, op); + }) + .Case([&](linalg::WinogradOutputTransformOp op) { + return decomposeWinogradOutputTransformOp(rewriter, op); + }) + .Default([&](Operation *op) { + return rewriter.notifyMatchFailure(op, "not supported"); + }); + + if (failed(maybeTransformed)) + return emitDefaultSilenceableFailure(target); + + results.push_back(*maybeTransformed); + return DiagnosedSilenceableFailure::success(); +} + #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 7e3dc56e0acdc..a7dcc29b5b9be 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Transforms.cpp TransposeConv2D.cpp Vectorization.cpp + WinogradConv2D.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp new file mode 100644 index 0000000000000..7cbd8ed9d44e8 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -0,0 +1,1118 @@ +//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implement Winograd Conv2D algorithm. 
The implementation is based on the +// paper: Fast Algorithms for Convolutional Neural Networks +// (https://arxiv.org/abs/1509.09308) +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/MathExtras.h" + +namespace mlir { +namespace linalg { + +namespace { + +// clang-format off +// Winograd Conv2D uses a minimal 2D filtering algorithm to calculate its +// result. The formula of minimal 2D filtering algorithm F(m x m, r x r), +// m is the output dimension and r is the filter dimension, is +// +// Y = A^T x [ (G x g x G^T) x (B^T x d x B) ] x A +// +// g is filter and d is input data. We need to prepare 6 constant +// transformation matrices, G, G^T, B^T, B, A^T, and A for this formula. +// +// The following tables define these constant transformation matrices for +// F(2 x 2, 3 x 3), F(4 x 4, 3 x 3), and F(2 x 2, 5 x 5) +constexpr float G_2x2_3x3[] = { + -1, 0, 0, + 1./2, -1./2, 1./2, + 1./2, 1./2, 1./2, + 0, 0, 1 +}; + +constexpr float GT_2x2_3x3[] = { + -1, 1./2, 1./2, 0, + 0, -1./2, 1./2, 0, + 0, 1./2, 1./2, 1 +}; + +constexpr float BT_2x2_3x3[] = { + -1, 0, 1, 0, + 0, -1, 1, 0, + 0, 1, 1, 0, + 0, -1, 0, 1 +}; + +constexpr float B_2x2_3x3[] = { + -1, 0, 0, 0, + 0, -1, 1, -1, + 1, 1, 1, 0, + 0, 0, 0, 1 +}; + +constexpr float AT_2x2_3x3[] = { + 1, 1, 1, 0, + 0, -1, 1, 1 +}; + +constexpr float A_2x2_3x3[] = { + 1, 0, + 1, -1, + 1, 1, + 0, 1 +}; + +constexpr float G_4x4_3x3[] = { + 1, 0, 0, + -1./3, 1./3, -1./3, + -1./3, -1./3, -1./3, + 1./12, -1./6, 1./3, + 1./12, 1./6, 1./3, + 0, 0, 1 +}; + +constexpr float GT_4x4_3x3[] = { + 1, -1./3, -1./3, 1./12, 1./12, 0, + 0, 1./3, -1./3, -1./6, 1./6, 0, + 0, -1./3, -1./3, 1./3, 1./3, 1 +}; + +constexpr float BT_4x4_3x3[] = { + 1./4, 0, -5./16, 0, 1./16, 0, + 0, 1./4, -1./4, -1./16, 1./16, 0, + 0, -1./4, -1./4, 1./16, 1./16, 0, + 0, 1./4, -1./8, -1./4, 1./8, 0, + 0, -1./4, -1./8, 1./4, 1./8, 0, + 0, 1./4, 0, -5./16, 0, 1./16 +}; + +constexpr float B_4x4_3x3[] = { + 1./4, 0, 0, 0, 0, 0, + 0, 1./4, -1./4, 1./4, -1./4, 1./4, + -5./16, -1./4, -1./4, -1./8, -1./8, 0, + 0, -1./16, 1./16, -1./4, 1./4, -5./16, + 1./16, 1./16, 1./16, 1./8, 1./8, 0, + 0, 0, 0, 0, 0, 1./16 +}; + +constexpr float AT_4x4_3x3[] = { + 1./8, 1./4, 1./4, 1./8, 1./8, 0, + 0, -1./4, 1./4, -1./4, 1./4, 0, + 0, 1./4, 1./4, 1./2, 1./2, 0, + 0, -1./4, 1./4, -1, 1, 1./2 +}; + +constexpr float A_4x4_3x3[] = { + 1./8, 0, 0, 0, + 1./4, -1./4, 1./4, -1./4, + 1./4, 1./4, 1./4, 1./4, + 1./8, -1./4, 1./2, -1, + 1./8, 1./4, 1./2, 1, + 0, 0, 0, 1./2 +}; + +constexpr float G_2x2_5x5[] = { + 1, 0, 0, 0, 0, + 1./6, -1./6, 1./6, -1./6, 1./6, + -1./6, -1./6, -1./6, -1./6, -1./6, +-4./15, 2./15, -1./15, 1./30, -1./60, + 1./60, 1./30, 1./15, 2./15, 4./15, + 0, 0, 0, 0, 1 +}; + +constexpr float GT_2x2_5x5[] = { + 1, 1./6, -1./6, -4./15, 1./60, 0, + 0, -1./6, -1./6, 2./15, 1./30, 0, + 0, 1./6, -1./6, -1./15, 1./15, 0, + 0, -1./6, -1./6, 1./30, 2./15, 0, + 0, 1./6, -1./6, -1./60, 4./15, 1 +}; + +constexpr float BT_2x2_5x5[] = { + 1./8, 3./16, -1./4, -3./16, 1./8, 0, + 0, 1./8, 1./16, -5./16, 1./8, 0, + 0, -1./8, -5./16, -1./16, 1./8, 0, + 0, 1./4, -1./8, -1./4, 1./8, 0, + 0, -1./8, -1./4, 1./8, 1./4, 0, + 0, 1./8, 
3./16, -1./4, -3./16, 1./8 +}; + +constexpr float B_2x2_5x5[] = { + 1./8, 0, 0, 0, 0, 0, + 3./16, 1./8, -1./8, 1./4, -1./8, 1./8, + -1./4, 1./16, -5./16, -1./8, -1./4, 3./16, + -3./16, -5./16, -1./16, -1./4, 1./8, -1./4, + 1./8, 1./8, 1./8, 1./8, 1./4, -3./16, + 0, 0, 0, 0, 0, 1./8 +}; + +constexpr float AT_2x2_5x5[] = { + 1./2, 1, 1, 2, 1, 0, + 0, -1, 1, -1, 2, 1./2 +}; + +constexpr float A_2x2_5x5[] = { + 1./2, 0, + 1, -1, + 1, 1, + 2, -1, + 1, 2, + 0, 1./2 +}; +// clang-format on + +using TransformMapKeyTy = std::pair; + +// We use F(m, r) to define the size of minimal filtering algorithms. +// m is the output dimension and r is the filter dimension. We can get +// the input dimension, alpha, from the formula, alpha = m + r - 1. +// +// For example, when m = 2 and r = 3, we know its input size is 4. +// The Conv2D will operate on 4x4 input data with 3x3 filter and get +// 2x2 output result. +constexpr TransformMapKeyTy F_2_3{2, 3}; +constexpr TransformMapKeyTy F_4_3{4, 3}; +constexpr TransformMapKeyTy F_2_5{2, 5}; + +struct TransformMatrix { + TransformMatrix(const float *table, int64_t rows, int64_t cols, + int64_t scalarFactor = 1) + : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {} + + const float *table; + int64_t rows; + int64_t cols; + int64_t scalarFactor; +}; + +Value create2DTransformMatrix(RewriterBase &rewriter, Location loc, + TransformMatrix transform, Type type) { + ArrayRef const_vec(transform.table, transform.rows * transform.cols); + + return rewriter.create( + loc, DenseFPElementsAttr::get( + RankedTensorType::get( + SmallVector{transform.rows, transform.cols}, type), + const_vec)); +} + +Value extract2DData(RewriterBase &rewriter, Location loc, Value source, + Value outLoopIndex, Value inLoopIndex, int64_t outLoopIdx, + int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx, + int64_t srcSize) { + auto sourceType = cast(source.getType()); + Type elementType = sourceType.getElementType(); + auto sourceShape = sourceType.getShape(); + int64_t height = sourceShape[heightIdx]; + int64_t width = sourceShape[widthIdx]; + + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector offsets(srcSize, zeroIndex); + offsets[outLoopIdx] = outLoopIndex; + offsets[inLoopIdx] = inLoopIndex; + SmallVector sizes(srcSize, oneIndex); + sizes[heightIdx] = rewriter.getIndexAttr(height); + sizes[widthIdx] = rewriter.getIndexAttr(width); + SmallVector strides(srcSize, oneIndex); + SmallVector targetShape(srcSize, 1); + targetShape[heightIdx] = height; + targetShape[widthIdx] = width; + + auto targetType = RankedTensorType::get(targetShape, elementType); + auto extractFilterOp = rewriter.create( + loc, targetType, source, offsets, sizes, strides); + + auto extractFilterType = RankedTensorType::get({height, width}, elementType); + auto extractFilter = tensor::createCanonicalRankReducingExtractSliceOp( + rewriter, loc, extractFilterOp, extractFilterType); + + return extractFilter; +} + +Value insert2DData(RewriterBase &rewriter, Location loc, Value source, + Value dest, Value outLoopIndex, Value inLoopIndex, + int64_t height, int64_t width, int64_t outLoopIdx, + int64_t inLoopIdx, int64_t heightIdx, int64_t widthIdx, + int64_t destSize) { + auto sourceType = cast(source.getType()); + Type elementType = sourceType.getElementType(); + SmallVector sliceShape(destSize, 1); + sliceShape[heightIdx] = height; + sliceShape[widthIdx] = width; + auto init = rewriter.create(loc, sliceShape, elementType); + auto result = 
tensor::createCanonicalRankReducingInsertSliceOp(rewriter, loc, + source, init); + + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector retOffsets(destSize, zeroIndex); + retOffsets[outLoopIdx] = outLoopIndex; + retOffsets[inLoopIdx] = inLoopIndex; + SmallVector retSizes(destSize, oneIndex); + retSizes[heightIdx] = rewriter.getIndexAttr(height); + retSizes[widthIdx] = rewriter.getIndexAttr(width); + SmallVector strides(destSize, oneIndex); + + auto insertSliceOp = rewriter.create( + loc, result, dest, retOffsets, retSizes, strides); + + return insertSliceOp; +} + +Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { + auto type = cast(data.getType()); + auto elementType = type.getElementType(); + auto shape = type.getShape(); + auto collapseType = RankedTensorType::get( + {shape[0] * shape[1] * shape[2] * shape[3], shape[4], shape[5]}, + elementType); + SmallVector reassociation = {{0, 1, 2, 3}, {4}, {5}}; + return rewriter.create(loc, collapseType, data, + reassociation); +} + +// This function transforms the filter. The data layout of the filter is FHWC. +// The transformation matrix is 2-dimension. We need to extract H x W from +// FHWC first. We need to generate 2 levels of loops to iterate on F and C. +// After the transformation, we get +// +// scf.for %f = lo_f to hi_f step 1 +// scf.for %c = lo_c to hi_c step 1 +// %extracted = extract filter from filter +// %ret = linalg.matmul G, %extracted +// %ret = linalg.matmul %ret, GT +// %inserted = insert %ret into filter +// +Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to G transform matrix. + static const llvm::SmallDenseMap + GMatrices = { + {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, + {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, + {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, + }; + + // Map from (m, r) to GT transform matrix. 
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> + GTMatrices = { + {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, + {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, + {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, + }; + + auto filterType = cast<ShapedType>(filter.getType()); + Type elementType = filterType.getElementType(); + auto filterShape = filterType.getShape(); // F, H, W, C + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + + if (filterH != r && filterH != 1) + return Value(); + if (filterW != r && filterW != 1) + return Value(); + + // Return shape is (tileH, tileW, alphaH, alphaW, C, F) + auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF); + auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC); + auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto outerForOp = + rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value FIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create<scf::ForOp>( + loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); + Block *innerForBody = innerForOp.getBody(); + rewriter.setInsertionPointToStart(innerForBody); + Value CIter = innerForBody->getArgument(0); + + // Extract (H, W) from (F, H, W, C) + auto extractFilter = extract2DData( + rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0, + /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + Value matmulRetValue = extractFilter; + if (leftTransform) { + // Get constant transform matrix G + auto it = GMatrices.find(key); + if (it == GMatrices.end()) + return Value(); + const TransformMatrix &GMatrix = it->second; + + retRows = GMatrix.rows; + auto matmulType = RankedTensorType::get({retRows, filterW}, elementType); + auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), + elementType); + + Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType); + // Multiply G x g + auto matmulOp = rewriter.create<linalg::MatmulOp>( + loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + if (rightTransform) { + // Get constant transform matrix GT + auto it = GTMatrices.find(key); + if (it == GTMatrices.end()) + return Value(); + const TransformMatrix &GTMatrix = it->second; + + auto matmulType = + RankedTensorType::get({retRows, GTMatrix.cols}, elementType); + auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), + elementType); + + Value GT = create2DTransformMatrix(rewriter, loc, GTMatrix, elementType); + // Multiply u = (G x g) x GT + auto matmulOp = rewriter.create<linalg::MatmulOp>( + loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + // Insert (H, W) to (1, 1, H, W, C, F) + Value iterArg = innerForOp.getRegionIterArgs()[0]; + int64_t retHeight = leftTransform ? m + r - 1 : 1; + int64_t retWidth = rightTransform ? m + r - 1 : 1; + auto insertSliceOp = insert2DData( + rewriter, loc, matmulRetValue, iterArg, FIter, CIter, retHeight, retWidth, + /*outLoopIdx=*/5, /*inLoopIdx=*/4, /*heightIdx=*/2, /*widthIdx=*/3, + /*destSize=*/6); + + rewriter.create<scf::YieldOp>(loc, insertSliceOp); + + rewriter.setInsertionPointToEnd(outerForBody); + rewriter.create<scf::YieldOp>(loc, innerForOp.getResult(0)); + + rewriter.setInsertionPointAfter(outerForOp); + + return outerForOp.getResult(0); +} + +// This function transforms the input. The data layout of the input is NHWC.
+// The transformation matrix is 2-dimension. We need to extract H x W from +// NHWC first. We need to generate 2 levels of loops to iterate on N and C. +// After the transformation, we get +// +// scf.for %n = lo_n to hi_n step 1 +// scf.for %c = lo_c to hi_c step 1 +// %extracted = extract input from input +// %ret = linalg.matmul BT, %extracted +// %ret = linalg.matmul %ret, B +// %inserted = insert %ret into input +// +Value inputTransform(RewriterBase &rewriter, Location loc, Value input, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to BT transform matrix. + static const llvm::SmallDenseMap + BTMatrices = { + {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)}, + {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)}, + {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)}, + }; + + // Map from (m, r) to B transform matrix. + static const llvm::SmallDenseMap + BMatrices = { + {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)}, + {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)}, + {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)}, + }; + + auto inputType = cast(input.getType()); + Type elementType = inputType.getElementType(); + auto inputShape = inputType.getShape(); // N, H, W, C + int64_t inputN = inputShape[0]; + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + int64_t inputC = inputShape[3]; + int64_t alphaH = leftTransform ? m + r - 1 : 1; + int64_t alphaW = rightTransform ? m + r - 1 : 1; + + if (inputH != alphaH && inputH != 1) + return Value(); + if (inputW != alphaW && inputW != 1) + return Value(); + + auto zeroIdx = rewriter.create(loc, 0); + auto nUpperBound = rewriter.create(loc, inputN); + auto cUpperBound = rewriter.create(loc, inputC); + auto oneStep = rewriter.create(loc, 1); + + auto outerForOp = + rewriter.create(loc, zeroIdx, nUpperBound, oneStep, retValue); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value NIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create( + loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); + Block *innerForBody = innerForOp.getBody(); + rewriter.setInsertionPointToStart(innerForBody); + Value CIter = innerForBody->getArgument(0); + + // Extract (H, W) from (N, H, W, C) + auto extractInput = extract2DData( + rewriter, loc, input, NIter, CIter, /*outLoopIdx=*/0, + /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + int64_t retCols = 1; + Value matmulRetValue = extractInput; + if (leftTransform) { + // Get constant transform matrix BT + auto it = BTMatrices.find(key); + if (it == BTMatrices.end()) + return Value(); + const TransformMatrix &BTMatrix = it->second; + + retRows = BTMatrix.rows; + auto matmulType = RankedTensorType::get({retRows, inputW}, elementType); + auto init = rewriter.create(loc, matmulType.getShape(), + elementType); + + Value BT = + create2DTransformMatrix(rewriter, loc, BTMatrix, rewriter.getF32Type()); + // Multiply BT x d + auto matmulOp = rewriter.create( + loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + if (rightTransform) { + // Get constant transform matrix B + auto it = BMatrices.find(key); + if (it == BMatrices.end()) + return Value(); + const TransformMatrix &BMatrix = it->second; + + retCols = BMatrix.cols; + auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); + auto init = rewriter.create(loc, 
matmulType.getShape(), + elementType); + Value B = + create2DTransformMatrix(rewriter, loc, BMatrix, rewriter.getF32Type()); + // Multiply v = (BT x d) x B + auto matmulOp = rewriter.create( + loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + // Insert v + // Insert (H, W) to (1, 1, H, W, N, C) + Value iterArg = innerForOp.getRegionIterArgs()[0]; + auto combinedVal = insert2DData( + rewriter, loc, matmulRetValue, iterArg, NIter, CIter, retRows, retCols, + /*outLoopIdx=*/4, /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, + /*destSize=*/6); + + rewriter.create(loc, combinedVal); + + rewriter.setInsertionPointToEnd(outerForBody); + rewriter.create(loc, innerForOp.getResult(0)); + + rewriter.setInsertionPointAfter(outerForOp); + + return outerForOp.getResult(0); +} + +// This function generates linalg.batch_matmul to multiply input with filter. +// linalg.batch_matmul only supports 3-dimension data sets. We can treat +// tileH x tileW x H x W data as the 1-dimension data array. That is to convert +// [tileH, tileW, H, W, N, C] to [tileH x tileW x H x W, N, C]. In this way, we +// can convert 6-dimension input data to 3-dimension representation that is +// suitable for linalg.batch_matmul. +// +// Batched matmul will do the matrix multiply with the reduction on channel. +// +// We get +// +// %collapsed_input = tensor.collapse_shape %input +// %collapsed_filter = tensor.collapse_shape %filter +// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter +// %expanded_ret = tensor.expand_shape %ret +// +// After this function, we get return value with data layout +// (tileH, tileW, H, W, N, F). +Value matrixMultiply(RewriterBase &rewriter, Location loc, + Value transformedFilter, Value transformedInput) { + auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter); + auto collapseInput = collapse2DData(rewriter, loc, transformedInput); + + // Batched matrix multiply + auto filterType = cast(transformedFilter.getType()); + auto filterShape = filterType.getShape(); + auto inputType = cast(transformedInput.getType()); + auto inputElemType = inputType.getElementType(); + auto inputShape = inputType.getShape(); + + auto matmulType = RankedTensorType::get( + {inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3], + inputShape[4], filterShape[5]}, + inputElemType); + Value init = rewriter.create(loc, matmulType.getShape(), + inputElemType); + + auto matmulOp = rewriter.create( + loc, matmulType, ValueRange({collapseInput, collapseFilter}), + ValueRange{init}); + + // Expand matmul result + SmallVector reassociation = {{0, 1, 2, 3}, {4}, {5}}; + auto expandType = + RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2], + inputShape[3], inputShape[4], filterShape[5]}, + inputElemType); + auto expandOutput = rewriter.create( + loc, expandType, matmulOp.getResult(0), reassociation); + return expandOutput; +} + +// This function transforms the output. The data layout of the output is HWNF. +// The transformation matrix is 2-dimension. We need to extract H x W from +// HWNF first. We need to generate 2 levels of loops to iterate on N and F. 
+// After the transformation, we get +// +// scf.for %n = lo_n to hi_n step 1 +// scf.for %f = lo_f to hi_f step 1 +// %extracted = extract input from result +// %ret = linalg.matmul AT, %extracted +// %ret = linalg.matmul %ret, A +// %inserted = insert %ret into ret +// +Value outputTransform(RewriterBase &rewriter, Location loc, Value value, + Value output, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to AT transform matrix. + static const llvm::SmallDenseMap + ATMatrices = { + {F_2_3, TransformMatrix(AT_2x2_3x3, 2, 4)}, + {F_4_3, TransformMatrix(AT_4x4_3x3, 4, 6, 32)}, + {F_2_5, TransformMatrix(AT_2x2_5x5, 2, 6, 16)}, + }; + + // Map from (m, r) to A transform matrix. + static const llvm::SmallDenseMap + AMatrices = { + {F_2_3, TransformMatrix(A_2x2_3x3, 4, 2)}, + {F_4_3, TransformMatrix(A_4x4_3x3, 6, 4, 32)}, + {F_2_5, TransformMatrix(A_2x2_5x5, 6, 2, 16)}, + }; + + auto valueType = cast(value.getType()); + Type elementType = valueType.getElementType(); + auto valueShape = valueType.getShape(); // TileH, TileW, H, W, N, F + int64_t valueH = valueShape[2]; + int64_t valueW = valueShape[3]; + int64_t valueN = valueShape[4]; + int64_t valueF = valueShape[5]; + int64_t alphaH = leftTransform ? m + r - 1 : 1; + int64_t alphaW = rightTransform ? m + r - 1 : 1; + + if (valueH != alphaH && valueH != 1) + return Value(); + if (valueW != alphaW && valueW != 1) + return Value(); + + auto zeroIdx = rewriter.create(loc, 0); + auto nUpperBound = rewriter.create(loc, valueN); + auto fUpperBound = rewriter.create(loc, valueF); + auto oneStep = rewriter.create(loc, 1); + + auto outerForOp = + rewriter.create(loc, zeroIdx, nUpperBound, oneStep, output); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value NIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create( + loc, zeroIdx, fUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); + Block *innerForBody = innerForOp.getBody(); + rewriter.setInsertionPointToStart(innerForBody); + Value FIter = innerForBody->getArgument(0); + + // Extract (H, W) from (1, 1, H, W, N, F) + auto extractValue = extract2DData( + rewriter, loc, value, NIter, FIter, /*outLoopIdx=*/4, + /*inLoopIdx=*/5, /*heightIdx=*/2, /*widthIdx=*/3, /*srcSize=*/6); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + int64_t retCols = 1; + int64_t leftScalarFactor = 1; + int64_t rightScalarFactor = 1; + Value matmulRetValue = extractValue; + if (leftTransform) { + // Get constant transform matrix AT + auto it = ATMatrices.find(key); + if (it == ATMatrices.end()) + return Value(); + const TransformMatrix &ATMatrix = it->second; + + leftScalarFactor = ATMatrix.scalarFactor; + retRows = ATMatrix.rows; + auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); + auto init = rewriter.create(loc, matmulType.getShape(), + elementType); + + Value AT = create2DTransformMatrix(rewriter, loc, ATMatrix, elementType); + // Multiply AT x m + auto matmulOp = rewriter.create( + loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + if (rightTransform) { + // Get constant transform matrix T + auto it = AMatrices.find(key); + if (it == AMatrices.end()) + return Value(); + const TransformMatrix &AMatrix = it->second; + + rightScalarFactor = AMatrix.scalarFactor; + auto matmulType = + RankedTensorType::get({retRows, AMatrix.cols}, elementType); + retCols = AMatrix.cols; + auto init = 
rewriter.create(loc, matmulType.getShape(), + elementType); + + Value A = create2DTransformMatrix(rewriter, loc, AMatrix, elementType); + // Multiply y = (AT x m) x A + auto matmulOp = rewriter.create( + loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init}); + matmulRetValue = matmulOp.getResult(0); + } + + // Multiply scalar factor. + Value scalarFactor = rewriter.create( + loc, FloatAttr::get(elementType, leftScalarFactor * rightScalarFactor)); + auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); + auto init = + rewriter.create(loc, matmulType.getShape(), elementType); + + auto identityAffineMap = rewriter.getMultiDimIdentityMap(2); + SmallVector affineMaps = {AffineMap::get(2, 0, init.getContext()), + identityAffineMap, identityAffineMap}; + auto scalarMatrixOp = rewriter.create( + loc, matmulType, ValueRange{scalarFactor, matmulRetValue}, + ValueRange{init}, affineMaps, tosa::getNParallelLoopsAttrs(2), + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + Value scalarVal = args[0]; + Value matrixVal = args[1]; + Value result = nestedBuilder.create(nestedLoc, scalarVal, + matrixVal); + nestedBuilder.create(nestedLoc, result); + }); + + // Insert slice y + // Insert (H, W) to (N, H, W, F) + Value iterArg = innerForOp.getRegionIterArgs()[0]; + Value combinedVal = insert2DData(rewriter, loc, scalarMatrixOp.getResult(0), + iterArg, NIter, FIter, retRows, retCols, + /*outLoopIdx=*/0, + /*inLoopIdx=*/3, /*heightIdx=*/1, + /*widthIdx=*/2, /*destSize=*/4); + + rewriter.create(loc, combinedVal); + + rewriter.setInsertionPointToEnd(outerForBody); + rewriter.create(loc, innerForOp.getResult(0)); + + rewriter.setInsertionPointAfter(outerForOp); + + return outerForOp.getResult(0); +} + +Value insertToAlignedTensor(RewriterBase &rewriter, Location loc, Value value, + RankedTensorType alignedType) { + Value alignedInput = rewriter.create( + loc, alignedType.getShape(), alignedType.getElementType()); + + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector offsets(4, zeroIndex); + SmallVector strides(4, oneIndex); + + auto valueType = cast(value.getType()); + auto valueShape = valueType.getShape(); + SmallVector sizes; + sizes.emplace_back(rewriter.getIndexAttr(valueShape[0])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[1])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[2])); + sizes.emplace_back(rewriter.getIndexAttr(valueShape[3])); + + return rewriter.create(loc, value, alignedInput, + offsets, sizes, strides); +} + +Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, + Value value, RankedTensorType extractedType) { + auto zeroIndex = rewriter.getIndexAttr(0); + auto oneIndex = rewriter.getIndexAttr(1); + SmallVector offsets(4, zeroIndex); + SmallVector strides(4, oneIndex); + + auto extractedShape = extractedType.getShape(); + SmallVector sizes; + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[0])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[1])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[2])); + sizes.emplace_back(rewriter.getIndexAttr(extractedShape[3])); + + return rewriter.create(loc, extractedType, value, + offsets, sizes, strides); +} + +bool hasAllOneValues(DenseIntElementsAttr attr) { + return llvm::all_of( + attr, [](const APInt &element) { return element.getSExtValue() == 1; }); +} + +FailureOr winogradConv2DHelper(RewriterBase &rewriter, + linalg::Conv2DNhwcFhwcOp convOp, + int64_t m, int64_t r) { + Value input 
= convOp.getInputs()[0]; + Value filter = convOp.getInputs()[1]; + Value output = convOp.getOutputs()[0]; + auto inputType = cast(input.getType()); + auto filterType = cast(filter.getType()); + auto outputType = cast(output.getType()); + + if (!inputType.hasStaticShape()) + return rewriter.notifyMatchFailure(convOp, + "expected a static shape for the input"); + + if (!filterType.hasStaticShape()) + return rewriter.notifyMatchFailure( + convOp, "expected a static shape for the filter"); + + if (!hasAllOneValues(convOp.getDilations())) + return rewriter.notifyMatchFailure(convOp, + "expected all ones for dilations"); + + if (!hasAllOneValues(convOp.getStrides())) + return rewriter.notifyMatchFailure(convOp, "expected all ones for strides"); + + auto filterShape = filterType.getShape(); + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + auto inputShape = inputType.getShape(); + int64_t inputN = inputShape[0]; + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + int64_t inputC = inputShape[3]; + auto outputShape = outputType.getShape(); + int64_t outputN = outputShape[0]; + int64_t outputH = outputShape[1]; + int64_t outputW = outputShape[2]; + int64_t outputF = outputShape[3]; + + // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r) + bool isSupportedFilter = false; + if (filterH == filterW && filterH == r) + isSupportedFilter = true; + if (filterH == r && filterW == 1) + isSupportedFilter = true; + if (filterH == 1 && filterW == r) + isSupportedFilter = true; + + if (!isSupportedFilter) + return rewriter.notifyMatchFailure( + convOp, "only support filter (r x r), (r x 1) or (1 x r)"); + + // Currently, we support (m, r) = (2, 3) or (4, 3) or (2, 5) + static const llvm::SmallVector validConfigs = { + F_2_3, F_4_3, F_2_5}; + + TransformMapKeyTy key = {m, r}; + auto it = std::find(validConfigs.begin(), validConfigs.end(), key); + // If we cannot find the constant transformation matrix, it means we do + // not support this configuration yet. + if (it == validConfigs.end()) + return failure(); + + // All the criterias are satisfied. We can do Winograd Conv2D. + Location loc = convOp.getLoc(); + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = filterH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. + bool rightTransform = filterW != 1; + int64_t heightM = leftTransform ? m : 1; + int64_t widthM = rightTransform ? m : 1; + int64_t heightR = leftTransform ? r : 1; + int64_t widthR = rightTransform ? r : 1; + + // --- Create operator for filter transform --- + Type elementType = filterType.getElementType(); + int64_t alphaH = heightM + heightR - 1; + int64_t alphaW = widthM + widthR - 1; + int64_t tileH = llvm::divideCeilSigned(outputH, heightM); + int64_t tileW = llvm::divideCeilSigned(outputW, widthM); + auto retType = RankedTensorType::get( + {tileH, tileW, alphaH, alphaW, filterC, filterF}, elementType); + Value retValue = + rewriter.create(loc, retType.getShape(), elementType); + auto transformedFilter = rewriter.create( + loc, retType, filter, retValue, m, r); + + // --- Create operator for input transform --- + + // When input size - (r - 1) is not aligned with output tile size, we need to + // pad the input data to create the full tiles as tiling. 
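+  // As a concrete illustration (numbers taken from the conv2d test added
+  // below): with F(4 x 4, 3 x 3), a 10 x 10 input producing an 8 x 8 output
+  // gives tileH = ceil(8 / 4) = 2 and alignedInputH = 2 * 4 + (3 - 1) = 10,
+  // so the input is already aligned and no padding is inserted. A 9 x 9
+  // output (11 x 11 input) would give tileH = 3 and alignedInputH = 14, so
+  // the input would be padded from 11 to 14 before the input transform.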
+ int64_t alignedInputH = tileH * heightM + (heightR - 1); + int64_t alignedInputW = tileW * widthM + (widthR - 1); + if (alignedInputH != inputH || alignedInputW != inputW) { + auto alignedInputType = RankedTensorType::get( + {inputN, alignedInputH, alignedInputW, inputC}, elementType); + input = insertToAlignedTensor(rewriter, loc, input, alignedInputType); + } + + retType = RankedTensorType::get( + {tileH, tileW, alphaH, alphaW, inputN, inputC}, elementType); + retValue = + rewriter.create(loc, retType.getShape(), elementType); + auto transformedInput = rewriter.create( + loc, retType, input, retValue, m, r); + + Value matmulRet = + matrixMultiply(rewriter, loc, transformedFilter, transformedInput); + + // --- Create operator for output transform --- + + // When output size is not aligned with output tile size, we need to pad the + // output buffer to insert the full tiles after tiling. + int64_t alignedOutputH = tileH * heightM; + int64_t alignedOutputW = tileW * widthM; + bool isOutputUnaligned = + ((alignedOutputH != outputH) || (alignedOutputW != outputW)); + if (isOutputUnaligned) { + auto alignedOutputType = RankedTensorType::get( + {outputN, alignedOutputH, alignedOutputW, outputF}, elementType); + output = insertToAlignedTensor(rewriter, loc, output, alignedOutputType); + outputType = alignedOutputType; + } + + Value transformedOutput = rewriter.create( + loc, outputType, matmulRet, output, m, r); + + // When output size is not aligned with output tile size, extract the + // value from the padded buffer. + if (isOutputUnaligned) { + transformedOutput = extractFromAlignedTensor( + rewriter, loc, transformedOutput, + RankedTensorType::get({outputN, outputH, outputW, outputF}, + elementType)); + } + + rewriter.replaceOp(convOp, transformedOutput); + + return transformedOutput.getDefiningOp(); +} + +FailureOr +decomposeWinogradFilterTransformHelper(RewriterBase &rewriter, + linalg::WinogradFilterTransformOp op) { + Location loc = op.getLoc(); + Value filter = op.getFilter(); + auto filterType = cast(filter.getType()); + auto filterShape = filterType.getShape(); + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = filterH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. + bool rightTransform = filterW != 1; + Value transformedFilter = + filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(), + op.getR(), leftTransform, rightTransform); + if (!transformedFilter) + return failure(); + + rewriter.replaceOp(op, transformedFilter); + + return transformedFilter.getDefiningOp(); +} + +FailureOr +decomposeWinogradInputTransformHelper(RewriterBase &rewriter, + linalg::WinogradInputTransformOp op) { + Location loc = op.getLoc(); + Value input = op.getInput(); + auto inputType = cast(input.getType()); + auto inputShape = inputType.getShape(); + int64_t inputH = inputShape[1]; + int64_t inputW = inputShape[2]; + + // For F(m x 1, r x 1), we only need to do left side transform. + bool leftTransform = inputH != 1; + // For F(1 x m, 1 x r), we only need to do right side transform. 
+  bool rightTransform = inputW != 1;
+  Value transformedInput =
+      inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
+                     op.getR(), leftTransform, rightTransform);
+  if (!transformedInput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedInput);
+
+  return transformedInput.getDefiningOp();
+}
+
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  Location loc = op.getLoc();
+  Value value = op.getValue();
+  auto valueType = cast<ShapedType>(value.getType());
+  auto valueShape = valueType.getShape();
+  int64_t valueH = valueShape[2];
+  int64_t valueW = valueShape[3];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = valueH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = valueW != 1;
+  Value transformedOutput =
+      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedOutput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+class DecomposeWinogradFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern<linalg::WinogradFilterTransformOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradFilterTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradInputTransform final
+    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+  using OpRewritePattern<linalg::WinogradInputTransformOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradOutputTransform final
+    : public OpRewritePattern<linalg::WinogradOutputTransformOp> {
+public:
+  using OpRewritePattern<linalg::WinogradOutputTransformOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradOutputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradOutputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class WinogradConv2DNhwcFhwc final
+    : public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
+public:
+  using OpRewritePattern<linalg::Conv2DNhwcFhwcOp>::OpRewritePattern;
+  WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
+      : OpRewritePattern<linalg::Conv2DNhwcFhwcOp>(context), m(m), r(r) {}
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
+      return failure();
+
+    return success();
+  }
+
+private:
+  int64_t m;
+  int64_t r;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r);
+}
+
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op) {
+  return decomposeWinogradFilterTransformHelper(rewriter, op);
+}
+
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op) {
+  return decomposeWinogradInputTransformHelper(rewriter, op);
+}
+
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op) {
+  return decomposeWinogradOutputTransformHelper(rewriter,
op);
+}
+
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<WinogradConv2DNhwcFhwc>(context, m, r);
+}
+
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.insert<DecomposeWinogradFilterTransform>(context);
+  patterns.insert<DecomposeWinogradInputTransform>(context);
+  patterns.insert<DecomposeWinogradOutputTransform>(context);
+}
+
+} // end namespace linalg
+} // end namespace mlir
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
new file mode 100644
index 0000000000000..39aeea1770101
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -0,0 +1,332 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+#map = affine_map<(d0, d1, d2, d3) -> (0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+  %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+  %4 = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+  %5 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%4 : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+  %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %5 [[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+  %6 = tensor.empty() : tensor<144x2x2xf32>
+  %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%6 : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+  %expanded = tensor.expand_shape %7 [[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+  %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<2x2x6x6x2x2xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %8 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op,
!transform.any_op) + %6 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %1 : (!transform.any_op) -> !transform.any_op + %7 = transform.structured.decompose_winograd_op %6 : (!transform.any_op) -> (!transform.any_op) + %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op + %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op) + %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op + %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> ()> +// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 1.024000e+03 : f32 +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32> +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32> +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32> +// CHECK-DAG: %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], 
[0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32> +// CHECK-DAG: %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32> +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<2x2x6x6x5x2xf32>) { +// CHECK-NEXT: %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x2x6x6x5x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S2]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 5, 2] [1, 1, 1, 1, 1, 1] : tensor<2x2x6x6x5x2xf32> to tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE]]) -> (tensor<1x1x6x6x5x2xf32>) { +// CHECK-NEXT: %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<1x1x6x6x5x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_7]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32> +// CHECK-NEXT: %[[S19:.*]] = tensor.empty() : tensor<6x3xf32> +// CHECK-NEXT: %[[S20:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_8]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S19]] : tensor<6x3xf32>) -> tensor<6x3xf32> +// CHECK-NEXT: %[[S21:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S23:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[INSERTED_SLICE_9]] into %[[ARG10]][0, 0, 0, 0, %[[ARG9]], %[[ARG7]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_10]] : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S18]] : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice 
%[[S15]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 5, 2] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x5x2xf32> into tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S12]] : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[S5:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<2x2x6x6x2x5xf32>) { +// CHECK-NEXT: %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x2x6x6x2x5xf32>) { +// CHECK-NEXT: %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]]) +// CHECK-NEXT: %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]]) +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x6x6x5xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<2x2x6x6x2x5xf32> to tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<1x1x6x6x2x5xf32>) { +// CHECK-NEXT: %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<1x1x6x6x2x5xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S19:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S20:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE]]_9 : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S21:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S23:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_11]] : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S18]] : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x5xf32> into tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S12]] : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %4 {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into 
tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S8:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S9:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S8]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S11:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x8x8x2xf32>) { +// CHECK-NEXT: %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x8x8x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<2x2x6x6x2x2xf32> to tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]]) +// CHECK-NEXT: %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]]) +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S1]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) { +// CHECK-NEXT: %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S17:.*]] = tensor.empty() : tensor<4x6xf32> +// CHECK-NEXT: %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32> +// CHECK-NEXT: %[[S19:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-NEXT: %[[S20:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %[[S21:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-NEXT: %[[S22:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]], #[[$MAP4]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S20]] : f32, tensor<4x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN_12:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: %[[S24:.*]] = arith.mulf %[[IN]], %[[IN_12]] : f32 +// CHECK-NEXT: linalg.yield %[[S24]] : f32 +// CHECK-NEXT: } -> tensor<4x4xf32> +// CHECK-NEXT: %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> 
into tensor<2x4x4x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_11]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S16]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[S16:.*]] = affine.apply #[[$MAP2]](%[[ARG3]]) +// CHECK-NEXT: %[[S17:.*]] = affine.apply #[[$MAP2]](%[[ARG5]]) +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x8x8x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S12]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: return %[[S11]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x9x9x2xf32> + %2 = tensor.empty() : tensor<3x3x6x6x5x2xf32> + %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> + %4 = tensor.empty() : tensor<2x14x14x5xf32> + %inserted_slice = tensor.insert_slice %arg0 into %4[0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> + %5 = tensor.empty() : tensor<3x3x6x6x2x5xf32> + %6 = linalg.winograd_input_transform m(4) r(3) ins(%inserted_slice : tensor<2x14x14x5xf32>) outs(%5 : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> + %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> + %collapsed_0 = tensor.collapse_shape %6 [[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> + %7 = tensor.empty() : tensor<324x2x2xf32> + %8 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%7 : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> + %expanded = tensor.expand_shape %8 [[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> + %9 = tensor.empty() : tensor<2x12x12x2xf32> + %inserted_slice_1 = tensor.insert_slice %1 into %9[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> + %10 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<3x3x6x6x2x2xf32>) outs(%inserted_slice_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> + %extracted_slice = tensor.extract_slice %10[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> + return %extracted_slice : tensor<2x9x9x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> 
!transform.any_op + %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [1, 1, 0, 0, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %6 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %1 : (!transform.any_op) -> !transform.any_op + %7 = transform.structured.decompose_winograd_op %6 : (!transform.any_op) -> (!transform.any_op) + %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op + %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op) + %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op + %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> ()> +// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 1.024000e+03 : f32 +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32> +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32> +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, 
-1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32> +// CHECK-DAG: %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32> +// CHECK-DAG: %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32> +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]]) -> (tensor<3x3x6x6x5x2xf32>) { +// CHECK-NEXT: %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<3x3x6x6x5x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S2]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 5, 2] [1, 1, 1, 1, 1, 1] : tensor<3x3x6x6x5x2xf32> to tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE]]) -> (tensor<1x1x6x6x5x2xf32>) { +// CHECK-NEXT: %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<1x1x6x6x5x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_7]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32> +// CHECK-NEXT: %[[S19:.*]] = tensor.empty() : tensor<6x3xf32> +// CHECK-NEXT: %[[S20:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_8]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S19]] : tensor<6x3xf32>) -> tensor<6x3xf32> +// CHECK-NEXT: %[[S21:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S23:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] 
[1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[INSERTED_SLICE_9]] into %[[ARG10]][0, 0, 0, 0, %[[ARG9]], %[[ARG7]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_10]] : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S18]] : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 5, 2] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x5x2xf32> into tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S12]] : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[INSERTED_INPUT_BUF:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S5:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S6]]) -> (tensor<3x3x6x6x2x5xf32>) { +// CHECK-NEXT: %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<3x3x6x6x2x5xf32>) { +// CHECK-NEXT: %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]]) +// CHECK-NEXT: %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]]) +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[INSERTED_INPUT_BUF]][0, %[[S13]], %[[S14]], 0] [2, 6, 6, 5] [1, 1, 1, 1] : tensor<2x14x14x5xf32> to tensor<2x6x6x5xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S5]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<3x3x6x6x2x5xf32> to tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<1x1x6x6x2x5xf32>) { +// CHECK-NEXT: %[[S18:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<1x1x6x6x2x5xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S19:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S20:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_9]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S19]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S21:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S22:.*]] = linalg.matmul ins(%[[S20]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S21]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S23:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][0, 
0, 0, 0, %[[ARG7]], %[[ARG9]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_11]] : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S18]] : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x5xf32> into tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S12]] : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %4 {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> +// CHECK-NEXT: %[[S8:.*]] = tensor.empty() : tensor<324x2x2xf32> +// CHECK-NEXT: %[[S9:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S8]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S9]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> +// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_OUTPUT_BUF:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S11:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S10]]) -> (tensor<2x12x12x2xf32>) { +// CHECK-NEXT: %[[S12:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x12x12x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][%[[ARG3]], %[[ARG5]], 0, 0, 0, 0] [1, 1, 6, 6, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<3x3x6x6x2x2xf32> to tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S13:.*]] = affine.apply #[[$MAP2]](%[[ARG3]]) +// CHECK-NEXT: %[[S14:.*]] = affine.apply #[[$MAP2]](%[[ARG5]]) +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[INSERTED_OUTPUT_BUF]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S15:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) -> (tensor<2x4x4x2xf32>) { +// CHECK-NEXT: %[[S16:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) -> (tensor<2x4x4x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_8]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S17:.*]] = tensor.empty() : tensor<4x6xf32> +// CHECK-NEXT: %[[S18:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_9]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S17]] : tensor<4x6xf32>) -> tensor<4x6xf32> +// CHECK-NEXT: %[[S19:.*]] = tensor.empty() : tensor<4x4xf32> 
+// CHECK-NEXT: %[[S20:.*]] = linalg.matmul ins(%[[S18]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S19]] : tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %[[S21:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-NEXT: %[[S22:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]], #[[$MAP4]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S20]] : f32, tensor<4x4xf32>) outs(%[[S21]] : tensor<4x4xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN_12:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: %[[S24:.*]] = arith.mulf %[[IN]], %[[IN_12]] : f32 +// CHECK-NEXT: linalg.yield %[[S24]] : f32 +// CHECK-NEXT: } -> tensor<4x4xf32> +// CHECK-NEXT: %[[S23:.*]] = tensor.empty() : tensor<1x4x4x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_10:.*]] = tensor.insert_slice %[[S22]] into %[[S23]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_11:.*]] = tensor.insert_slice %[[INSERTED_SLICE_10]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_11]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S16]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[S16:.*]] = affine.apply #[[$MAP2]](%[[ARG3]]) +// CHECK-NEXT: %[[S17:.*]] = affine.apply #[[$MAP2]](%[[ARG5]]) +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[ARG6]][0, %[[S16]], %[[S17]], 0] [2, 4, 4, 2] [1, 1, 1, 1] : tensor<2x4x4x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S12]] : tensor<2x12x12x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[RET:.*]] = tensor.extract_slice %[[S11]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32 +// CHECK-NEXT: return %[[RET]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir new file mode 100644 index 0000000000000..1e74fea5a1c31 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s + +func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<2x8x8x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x8x8x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %2 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, 
d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x9x9x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %2 : tensor<2x9x9x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_unaligned +// 
CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> +// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir new file mode 100644 index 0000000000000..917d089c1981c --- /dev/null +++ b/mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir @@ -0,0 +1,105 @@ +// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-decompose-winograd-ops | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3) -> (0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) 
outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = tensor.empty() : tensor<1x1x6x6x5x2xf32> + %3 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%2 : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> + %4 = tensor.empty() : tensor<1x1x6x6x2x5xf32> + %5 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x6x5xf32>) outs(%4 : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> + %collapsed = tensor.collapse_shape %3 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> + %collapsed_0 = tensor.collapse_shape %5 [[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> + %6 = tensor.empty() : tensor<36x2x2xf32> + %7 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%6 : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> + %expanded = tensor.expand_shape %7 [[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> + %8 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<1x1x6x6x2x2xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %8 : tensor<2x4x4x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> ()> +// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @conv2d_4x4_3x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 1.024000e+03 : f32 +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00], [2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01], [2.500000e-01, 2.500000e-01, 2.500000e-01, 2.500000e-01], [1.250000e-01, -2.500000e-01, 5.000000e-01, -1.000000e+00], [1.250000e-01, 2.500000e-01, 5.000000e-01, 1.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 5.000000e-01]]> : tensor<6x4xf32> +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<{{\[}}[1.250000e-01, 2.500000e-01, 2.500000e-01, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 2.500000e-01, 5.000000e-01, 5.000000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, 2.500000e-01, -1.000000e+00, 1.000000e+00, 5.000000e-01]]> : tensor<4x6xf32> +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, 2.500000e-01, -2.500000e-01, 2.500000e-01], [-3.125000e-01, -2.500000e-01, -2.500000e-01, -1.250000e-01, -1.250000e-01, 0.000000e+00], [0.000000e+00, -6.250000e-02, 6.250000e-02, -2.500000e-01, 2.500000e-01, -3.125000e-01], [6.250000e-02, 6.250000e-02, 6.250000e-02, 1.250000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32> +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<{{\[}}[2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02, 0.000000e+00], [0.000000e+00, 2.500000e-01, -2.500000e-01, -6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, -2.500000e-01, -2.500000e-01, 6.250000e-02, 6.250000e-02, 0.000000e+00], [0.000000e+00, 
2.500000e-01, -1.250000e-01, -2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, -2.500000e-01, -1.250000e-01, 2.500000e-01, 1.250000e-01, 0.000000e+00], [0.000000e+00, 2.500000e-01, 0.000000e+00, -3.125000e-01, 0.000000e+00, 6.250000e-02]]> : tensor<6x6xf32> +// CHECK-DAG: %[[CST_4:.*]] = arith.constant dense<{{\[}}[1.000000e+00, -0.333333343, -0.333333343, 0.0833333358, 0.0833333358, 0.000000e+00], [0.000000e+00, 0.333333343, -0.333333343, -0.166666672, 0.166666672, 0.000000e+00], [0.000000e+00, -0.333333343, -0.333333343, 0.333333343, 0.333333343, 1.000000e+00]]> : tensor<3x6xf32> +// CHECK-DAG: %[[CST_5:.*]] = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00, 0.000000e+00], [-0.333333343, 0.333333343, -0.333333343], [-0.333333343, -0.333333343, -0.333333343], [0.0833333358, -0.166666672, 0.333333343], [0.0833333358, 0.166666672, 0.333333343], [0.000000e+00, 0.000000e+00, 1.000000e+00]]> : tensor<6x3xf32> +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) -> (tensor<1x1x6x6x5x2xf32>) { +// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x5x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<1x3x3x1xf32> to tensor<3x3xf32> +// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<6x3xf32> +// CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[CST_5]], %[[EXTRACTED_SLICE_7]] : tensor<6x3xf32>, tensor<3x3xf32>) outs(%[[S10]] : tensor<6x3xf32>) -> tensor<6x3xf32> +// CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_4]] : tensor<6x3xf32>, tensor<3x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG5]], %[[ARG3]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S9]] : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] 
iter_args(%[[ARG4:.*]] = %[[S4]]) -> (tensor<1x1x6x6x2x5xf32>) { +// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<1x1x6x6x2x5xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<2x6x6x5xf32> to tensor<1x6x6x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0] [1, 6, 6, 1] [1, 1, 1, 1] : tensor<1x6x6x1xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[CST_3]], %[[EXTRACTED_SLICE_7]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<6x6xf32> +// CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_2]] : tensor<6x6xf32>, tensor<6x6xf32>) outs(%[[S12]] : tensor<6x6xf32>) -> tensor<6x6xf32> +// CHECK-NEXT: %[[S14:.*]] = tensor.empty() : tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[S14]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6xf32> into tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> into tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_8]] : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S9]] : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_6]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S1]]) -> (tensor<2x4x4x2xf32>) { +// CHECK-NEXT: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) -> (tensor<2x4x4x2xf32>) { +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x2x2xf32> to tensor<1x1x6x6x1x1xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, 0, 0] [1, 1, 6, 6, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x6x6x1x1xf32> to tensor<6x6xf32> +// CHECK-NEXT: %[[S10:.*]] = tensor.empty() : tensor<4x6xf32> +// CHECK-NEXT: %[[S11:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_7]] : tensor<4x6xf32>, tensor<6x6xf32>) outs(%[[S10]] : tensor<4x6xf32>) -> tensor<4x6xf32> +// CHECK-NEXT: %[[S12:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK-NEXT: %[[S13:.*]] = linalg.matmul ins(%[[S11]], %[[CST_0]] : tensor<4x6xf32>, tensor<6x4xf32>) outs(%[[S12]] : tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK-NEXT: %[[S14:.*]] = 
tensor.empty() : tensor<4x4xf32> +// CHECK-NEXT: %[[S15:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP3]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]], %[[S13]] : f32, tensor<4x4xf32>) outs(%[[S14]] : tensor<4x4xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[IN_9:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: %[[S17:.*]] = arith.mulf %[[IN]], %[[IN_9]] : f32 +// CHECK-NEXT: linalg.yield %[[S17]] : f32 +// CHECK-NEXT: } -> tensor<4x4xf32> +// CHECK-NEXT: %[[S16:.*]] = tensor.empty() : tensor<1x4x4x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S15]] into %[[S16]][0, 0, 0, 0] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<1x4x4x1xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_8:.*]] = tensor.insert_slice %[[INSERTED_SLICE]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<1x4x4x1xf32> into tensor<2x4x4x2xf32> +// CHECK-NEXT: scf.yield %[[INSERTED_SLICE_8]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %[[S9]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } +// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32> +// CHECK-NEXT:} diff --git a/mlir/test/Dialect/Linalg/winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir new file mode 100644 index 0000000000000..6cca3c602d4c0 --- /dev/null +++ b/mlir/test/Dialect/Linalg/winograd-conv2d.mlir @@ -0,0 +1,248 @@ +// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-winograd-conv2d | FileCheck %s + +func.func @conv2d_4x4_3x3(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_4x4_3x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x4x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x4x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x4x4x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = 
tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_2x2_5x5(%arg0: tensor<2x6x6x5xf32>, %arg1: tensor<2x5x5x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x2x2x2xf32> { + %0 = tensor.empty() : tensor<2x2x2x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x2x2x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x2x2x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x6x5xf32>, tensor<2x5x5x5xf32>) outs(%1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> + return %2 : tensor<2x2x2x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_2x2_5x5 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x6x5xf32>, %[[ARG1:.*]]: tensor<2x5x5x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x2x2x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x2x2x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x2x2x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x2x2x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(2) r(5) ins(%[[ARG1]] : tensor<2x5x5x5xf32>) outs(%[[S2]] : tensor<1x1x6x6x5x2xf32>) -> tensor<1x1x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(2) r(5) ins(%[[ARG0]] : tensor<2x6x6x5xf32>) outs(%[[S4]] : tensor<1x1x6x6x2x5xf32>) -> tensor<1x1x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x5x2xf32> into tensor<36x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x6x2x5xf32> into tensor<36x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<36x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<36x2x5xf32>, tensor<36x5x2xf32>) outs(%[[S6]] : tensor<36x2x2xf32>) -> tensor<36x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape 
%[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 6, 2, 2] : tensor<36x2x2xf32> into tensor<1x1x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(2) r(5) ins(%[[EXPANDED]] : tensor<1x1x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x2x2x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_1x4_1x3(%arg0: tensor<2x1x6x5xf32>, %arg1: tensor<2x1x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x1x4x2xf32> { + %0 = tensor.empty() : tensor<2x1x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x1x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x1x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x1x6x5xf32>, tensor<2x1x3x5xf32>) outs(%1 : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> + return %2 : tensor<2x1x4x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_1x4_1x3 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x1x6x5xf32>, %[[ARG1:.*]]: tensor<2x1x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x1x4x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x1x4x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x1x4x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x1x4x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x1x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x1x3x5xf32>) outs(%[[S2]] : tensor<1x1x1x6x5x2xf32>) -> tensor<1x1x1x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x1x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x1x6x5xf32>) outs(%[[S4]] : tensor<1x1x1x6x2x5xf32>) -> tensor<1x1x1x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x1x6x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 1, 6, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x1x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x1x6x2x2xf32>) outs(%[[S1]] : tensor<2x1x4x2xf32>) -> tensor<2x1x4x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x1x4x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_4x1_3x1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x1x2xf32> { + %0 = tensor.empty() : tensor<2x4x1x2xf32> + %1 = 
linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x1x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x1x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x1x5xf32>, tensor<2x3x1x5xf32>) outs(%1 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> + return %2 : tensor<2x4x1x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_4x1_3x1 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x4x1x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x4x1x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x4x1x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x4x1x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<1x1x6x1x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x1x5xf32>) outs(%[[S2]] : tensor<1x1x6x1x5x2xf32>) -> tensor<1x1x6x1x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<1x1x6x1x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x6x1x5xf32>) outs(%[[S4]] : tensor<1x1x6x1x2x5xf32>) -> tensor<1x1x6x1x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x5x2xf32> into tensor<6x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<1x1x6x1x2x5xf32> into tensor<6x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<6x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%[[S6]] : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [1, 1, 6, 1, 2, 2] : tensor<6x2x2xf32> into tensor<1x1x6x1x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<1x1x6x1x2x2xf32>) outs(%[[S1]] : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x4x1x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_aligned(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<2x8x8x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x8x8x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %2 : 
tensor<2x8x8x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_aligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x8x8x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> +// CHECK-NEXT: return %[[S8]] : tensor<2x8x8x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> { + %0 = tensor.empty() : tensor<2x9x9x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x9x9x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> + return %2 : tensor<2x9x9x2xf32> +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> { +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", 
"parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) { +// CHECK-NEXT: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[IN]] : f32 +// CHECK-NEXT: } -> tensor<2x9x9x2xf32> +// CHECK-NEXT: %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32> +// CHECK-NEXT: %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32> +// CHECK-NEXT: %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32> +// CHECK-NEXT: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32> +// CHECK-NEXT: %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32> +// CHECK-NEXT: %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32> +// CHECK-NEXT: %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32> +// CHECK-NEXT: %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32> +// CHECK-NEXT: %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> +// CHECK-NEXT: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> +// CHECK-NEXT: return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32> +// CHECK-NEXT: } + +// ----- + +func.func @conv2d_unsupported_1(%arg0: tensor<2x6x5x5xf32>, %arg1: tensor<2x3x2x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x6x5x5xf32>, tensor<2x3x2x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK-LABEL: conv2d_unsupported_1 +// CHECK: linalg.conv_2d_nhwc_fhwc + +// ----- + +func.func @conv2d_unsupported_2(%arg0: tensor<2x7x7x5xf32>, %arg1: tensor<2x4x4x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x4x4x2xf32> { + %0 = tensor.empty() : tensor<2x4x4x2xf32> + %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, 
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x4x4x2xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<2x4x4x2xf32> + %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x7x7x5xf32>, tensor<2x4x4x5xf32>) outs(%1 : tensor<2x4x4x2xf32>) -> tensor<2x4x4x2xf32> + return %2 : tensor<2x4x4x2xf32> +} + +// CHECK-LABEL: conv2d_unsupported_2 +// CHECK: linalg.conv_2d_nhwc_fhwc + +// ----- + +func.func @conv2d_unsupported_3(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> { + %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<2x3x3x5xf32>) outs(%arg2 : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> + return %0 : tensor<?x?x?x?xf32> +} + +// CHECK-LABEL: conv2d_unsupported_3 +// CHECK: linalg.conv_2d_nhwc_fhwc diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 4892fa2f99a7c..5899f56da7345 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -123,6 +123,13 @@ struct TestLinalgTransforms *this, "test-erase-unnecessary-inputs", llvm::cl::desc("Test patterns to erase unnecessary inputs"), llvm::cl::init(false)}; + Option<bool> testWinogradConv2D{ *this, "test-winograd-conv2d", llvm::cl::desc("Test transform conv2d by Winograd conv2d algorithm"), llvm::cl::init(false)}; + Option<bool> testDecomposeWinogradOps{ *this, "test-decompose-winograd-ops", llvm::cl::desc("Test decompose Winograd ops"), llvm::cl::init(false)}; }; } // namespace @@ -207,6 +214,19 @@ static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static void applyWinogradConv2D(func::FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); + populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + +static void applyDecomposeWinogradOps(func::FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + populateDecomposeWinogradOpsPatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnOperation() { if (testPatterns) @@ -231,6 +251,10 @@ void TestLinalgTransforms::runOnOperation() { return applyEraseUnusedOperandsAndResultsPatterns(getOperation()); if (testEraseUnnecessaryInputs) return applyEraseUnnecessaryInputs(getOperation()); + if (testWinogradConv2D) + return applyWinogradConv2D(getOperation()); + if (testDecomposeWinogradOps) + return applyDecomposeWinogradOps(getOperation()); } namespace mlir {