diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td index 5910aa3f7f2da..f5cf3dad75d9c 100644 --- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td +++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td @@ -198,7 +198,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface auto memrefType = llvm::dyn_cast(getSourceType()); assert(memrefType && "Incorrect use of getStaticStrides"); - auto [strides, offset] = getStridesAndOffset(memrefType); + auto [strides, offset] = memrefType.getStridesAndOffset(); // reuse the storage of ConstStridesAttr since strides from // memref is not persistant setConstStrides(strides); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 19c5361124aac..df1e02732617d 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -409,33 +409,6 @@ inline bool TensorType::classof(Type type) { // Type Utilities //===----------------------------------------------------------------------===// -/// Returns the strides of the MemRef if the layout map is in strided form. -/// MemRefs with a layout map in strided form include: -/// 1. empty or identity layout map, in which case the stride information is -/// the canonical form computed from sizes; -/// 2. a StridedLayoutAttr layout; -/// 3. any other layout that be converted into a single affine map layout of -/// the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or -/// symbols. -/// -/// A stride specification is a list of integer values that are either static -/// or dynamic (encoded with ShapedType::kDynamic). Strides encode -/// the distance in the number of elements between successive entries along a -/// particular dimension. -LogicalResult getStridesAndOffset(MemRefType t, - SmallVectorImpl &strides, - int64_t &offset); - -/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl, -/// int64_t) that will assert if the logical result is not succeeded. -std::pair, int64_t> getStridesAndOffset(MemRefType t); - -/// Return a version of `t` with identity layout if it can be determined -/// statically that the layout is the canonical contiguous strided layout. -/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of -/// `t` with simplified layout. -MemRefType canonicalizeStridedLayout(MemRefType t); - /// Given MemRef `sizes` that are either static or dynamic, returns the /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and /// once a dynamic dimension is encountered, all canonical strides become @@ -458,24 +431,6 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef sizes, /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)} AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef sizes, MLIRContext *context); - -/// Return "true" if the layout for `t` is compatible with strided semantics. -bool isStrided(MemRefType t); - -/// Return "true" if the last dimension of the given type has a static unit -/// stride. Also return "true" for types with no strides. -bool isLastMemrefDimUnitStride(MemRefType type); - -/// Return "true" if the last N dimensions of the given type are contiguous. -/// -/// Examples: -/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when -/// considering both _all_ and _only_ the trailing 3 dims, -/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when -/// considering the trailing 3 dims. -/// -bool trailingNDimsContiguous(MemRefType type, int64_t n); - } // namespace mlir #endif // MLIR_IR_BUILTINTYPES_H diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 4f09d2e41e7ce..e5a2ae81da0c9 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -808,10 +808,52 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ /// Arguments that are passed into the builder must outlive the builder. class Builder; + /// Return "true" if the last N dimensions are contiguous. + /// + /// Examples: + /// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when + /// considering both _all_ and _only_ the trailing 3 dims, + /// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when + /// considering the trailing 3 dims. + /// + bool areTrailingDimsContiguous(int64_t n); + + /// Return a version of this type with identity layout if it can be + /// determined statically that the layout is the canonical contiguous + /// strided layout. Otherwise pass the layout into `simplifyAffineMap` + /// and return a copy of this type with simplified layout. + MemRefType canonicalizeStridedLayout(); + /// [deprecated] Returns the memory space in old raw integer representation. /// New `Attribute getMemorySpace()` method should be used instead. unsigned getMemorySpaceAsInt() const; + /// Returns the strides of the MemRef if the layout map is in strided form. + /// MemRefs with a layout map in strided form include: + /// 1. empty or identity layout map, in which case the stride information + /// is the canonical form computed from sizes; + /// 2. a StridedLayoutAttr layout; + /// 3. any other layout that be converted into a single affine map layout + /// of the form `K + k0 * d0 + ... kn * dn`, where K and ki's are + /// constants or symbols. + /// + /// A stride specification is a list of integer values that are either + /// static or dynamic (encoded with ShapedType::kDynamic). Strides encode + /// the distance in the number of elements between successive entries along + /// a particular dimension. + LogicalResult getStridesAndOffset(SmallVectorImpl &strides, + int64_t &offset); + + /// Wrapper around getStridesAndOffset(SmallVectorImpl, int64_t) + /// that will assert if the logical result is not succeeded. + std::pair, int64_t> getStridesAndOffset(); + + /// Return "true" if the layout is compatible with strided semantics. + bool isStrided(); + + /// Return "true" if the last dimension has a static unit stride. Also + /// return "true" for types with no strides. + bool isLastDimUnitStride(); }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index e752cdfb47fbb..5ec995b3ae977 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -820,7 +820,7 @@ class StaticShapeMemRefOf allowedTypes> : def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>; // For a MemRefType, verify that it has strides. -def HasStridesPred : CPred<[{ isStrided(::llvm::cast<::mlir::MemRefType>($_self)) }]>; +def HasStridesPred : CPred<[{ ::llvm::cast<::mlir::MemRefType>($_self).isStrided() }]>; class StridedMemRefOf allowedTypes> : ConfinedType, [HasStridesPred], diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 98ca9c3d23909..a080adf0f8103 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -524,7 +524,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *offset) { MemRefType memrefType = llvm::cast(unwrap(type)); SmallVector strides_; - if (failed(getStridesAndOffset(memrefType, strides_, *offset))) + if (failed(memrefType.getStridesAndOffset(strides_, *offset))) return mlirLogicalResultFailure(); (void)std::copy(strides_.begin(), strides_.end(), strides); diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 5d09d6f1d6952..51f5d7a161b90 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -192,7 +192,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { // Construct buffer descriptor from memref, attributes int64_t offset = 0; SmallVector strides; - if (failed(getStridesAndOffset(memrefType, strides, offset))) + if (failed(memrefType.getStridesAndOffset(strides, offset))) return gpuOp.emitOpError("Can't lower non-stride-offset memrefs"); MemRefDescriptor memrefDescriptor(memref); diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index 19c3ba1f95020..63f99eb744a83 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -52,7 +52,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape( assert(type.hasStaticShape() && "unexpected dynamic shape"); // Extract all strides and offsets and verify they are static. - auto [strides, offset] = getStridesAndOffset(type); + auto [strides, offset] = type.getStridesAndOffset(); assert(!ShapedType::isDynamic(offset) && "expected static offset"); assert(!llvm::any_of(strides, ShapedType::isDynamic) && "expected static strides"); @@ -193,7 +193,7 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc, MemRefType type) { // When we convert to LLVM, the input memref must have been normalized // beforehand. Hence, this call is guaranteed to work. - auto [strides, offsetCst] = getStridesAndOffset(type); + auto [strides, offsetCst] = type.getStridesAndOffset(); Value ptr = alignedPtr(builder, loc); // For zero offsets, we already have the base pointer. diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index d551506485a45..a47a2872ceb07 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -62,7 +62,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr( Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { - auto [strides, offset] = getStridesAndOffset(type); + auto [strides, offset] = type.getStridesAndOffset(); MemRefDescriptor memRefDescriptor(memRefDesc); // Use a canonical representation of the start address so that later diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 247a8ab28a44b..ea251e4564ea8 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -485,7 +485,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const { SmallVector LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) const { - if (!isStrided(type)) { + if (!type.isStrided()) { emitError( UnknownLoc::get(type.getContext()), "conversion to strided form failed either due to non-strided layout " @@ -603,7 +603,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) { int64_t offset = 0; SmallVector strides; - if (failed(getStridesAndOffset(memrefTy, strides, offset))) + if (failed(memrefTy.getStridesAndOffset(strides, offset))) return false; for (int64_t stride : strides) diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 86f687d7f2636..f7542b8b3bc5c 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1136,7 +1136,7 @@ struct MemRefReshapeOpLowering // Extract the offset and strides from the type. int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(targetMemRefType, strides, offset))) + if (failed(targetMemRefType.getStridesAndOffset(strides, offset))) return rewriter.notifyMatchFailure( reshapeOp, "failed to get stride and offset exprs"); @@ -1451,7 +1451,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { int64_t offset; SmallVector strides; - auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); + auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset); if (failed(successStrides)) return viewOp.emitWarning("cannot cast to non-strided shape"), failure(); assert(offset == 0 && "expected offset to be 0"); @@ -1560,7 +1560,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering { auto memRefType = atomicOp.getMemRefType(); SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(memRefType, strides, offset))) + if (failed(memRefType.getStridesAndOffset(strides, offset))) return failure(); auto dataPtr = getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(), diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 5b4414d67fdac..eaefe9e385793 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -132,7 +132,7 @@ static std::optional getStaticallyKnownRowStride(ShapedType type) { return 0; int64_t offset = 0; SmallVector strides; - if (failed(getStridesAndOffset(memrefType, strides, offset)) || + if (failed(memrefType.getStridesAndOffset(strides, offset)) || strides.back() != 1) return std::nullopt; int64_t stride = strides[strides.size() - 2]; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index d688d8e2ab658..a1e21cb524bd9 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -91,7 +91,7 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter, // Check if the last stride is non-unit and has a valid memory space. static LogicalResult isMemRefTypeSupported(MemRefType memRefType, const LLVMTypeConverter &converter) { - if (!isLastMemrefDimUnitStride(memRefType)) + if (!memRefType.isLastDimUnitStride()) return failure(); if (failed(converter.getMemRefAddressSpace(memRefType))) return failure(); @@ -1374,7 +1374,7 @@ static std::optional> computeContiguousStrides(MemRefType memRefType) { int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(memRefType, strides, offset))) + if (failed(memRefType.getStridesAndOffset(strides, offset))) return std::nullopt; if (!strides.empty() && strides.back() != 1) return std::nullopt; diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 01bc65c841e94..22bf27d229ce5 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1650,7 +1650,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern { return failure(); if (xferOp.getVectorType().getRank() != 1) return failure(); - if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) + if (map.isMinorIdentity() && memRefType.isLastDimUnitStride()) return failure(); // Handled by ConvertVectorToLLVM // Loop bounds, step, state... diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 8041bdf7da19b..d3229d2e91296 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -76,8 +76,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter, // Validate further transfer op semantics. SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(srcTy, strides, offset)) || - strides.back() != 1) + if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1) return rewriter.notifyMatchFailure( xferOp, "Buffer must be contiguous in the innermost dimension"); @@ -105,7 +104,7 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, xegpu::TensorDescType descType, TypedValue src, Operation::operand_range offsets) { MemRefType srcTy = src.getType(); - auto [strides, offset] = getStridesAndOffset(srcTy); + auto [strides, offset] = srcTy.getStridesAndOffset(); xegpu::CreateNdDescOp ndDesc; if (srcTy.hasStaticShape()) { diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 5af0cb0c7ba1c..271ca382e2f0b 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -129,7 +129,7 @@ static bool staticallyOutOfBounds(OpType op) { return false; int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(bufferType, strides, offset))) + if (failed(bufferType.getStridesAndOffset(strides, offset))) return false; int64_t result = offset + op.getIndexOffset().value_or(0); if (op.getSgprOffset()) { diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp index 4eac371d4c1ae..4cb777b03b196 100644 --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -53,8 +53,7 @@ FailureOr getStride(ConversionPatternRewriter &rewriter, unsigned bytes = width >> 3; int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(mType, strides, offset)) || - strides.back() != 1) + if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1) return failure(); if (strides[preLast] == ShapedType::kDynamic) { // Dynamic stride needs code to compute the stride at runtime. diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index f1841b860ff81..6be55a1d28224 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -42,8 +42,8 @@ FailureOr mlir::bufferization::castOrReallocMemRefValue( auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) { int64_t sourceOffset, targetOffset; SmallVector sourceStrides, targetStrides; - if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) || - failed(getStridesAndOffset(target, targetStrides, targetOffset))) + if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) || + failed(target.getStridesAndOffset(targetStrides, targetOffset))) return false; auto dynamicToStatic = [](int64_t a, int64_t b) { return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 2502744cb3f58..ce0f112dc2dd2 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -29,7 +29,7 @@ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn; static bool hasFullyDynamicLayoutMap(MemRefType type) { int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(type, strides, offset))) + if (failed(type.getStridesAndOffset(strides, offset))) return false; if (!llvm::all_of(strides, ShapedType::isDynamic)) return false; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 49209229259a7..301066e7d3e1f 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1903,7 +1903,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() { auto operand = resMatrixType.getOperand(); auto srcMemrefType = llvm::cast(srcType); - if (!isLastMemrefDimUnitStride(srcMemrefType)) + if (!srcMemrefType.isLastDimUnitStride()) return emitError( "expected source memref most minor dim must have unit stride"); @@ -1923,7 +1923,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() { auto srcMatrixType = llvm::cast(srcType); auto dstMemrefType = llvm::cast(dstType); - if (!isLastMemrefDimUnitStride(dstMemrefType)) + if (!dstMemrefType.isLastDimUnitStride()) return emitError( "expected destination memref most minor dim must have unit stride"); diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp index a504101fb3f2f..2afdeff3a7be1 100644 --- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp @@ -67,7 +67,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, rewriter.create(loc, source); } - auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType); + auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult { return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 9aae46a5c288d..4f75b7618d636 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -163,7 +163,7 @@ static SmallVector getConstantOffset(MemRefType memrefType) { SmallVector strides; int64_t offset; LogicalResult hasStaticInformation = - getStridesAndOffset(memrefType, strides, offset); + memrefType.getStridesAndOffset(strides, offset); if (failed(hasStaticInformation)) return SmallVector(); return SmallVector(1, offset); @@ -176,7 +176,7 @@ static SmallVector getConstantStrides(MemRefType memrefType) { SmallVector strides; int64_t offset; LogicalResult hasStaticInformation = - getStridesAndOffset(memrefType, strides, offset); + memrefType.getStridesAndOffset(strides, offset); if (failed(hasStaticInformation)) return SmallVector(); return strides; @@ -663,8 +663,8 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) { // Only fold casts between strided memref forms. int64_t sourceOffset, resultOffset; SmallVector sourceStrides, resultStrides; - if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || - failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) || + failed(resultType.getStridesAndOffset(resultStrides, resultOffset))) return false; // If cast is towards more static sizes along any dimension, don't fold. @@ -708,8 +708,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { if (aT.getLayout() != bT.getLayout()) { int64_t aOffset, bOffset; SmallVector aStrides, bStrides; - if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || - failed(getStridesAndOffset(bT, bStrides, bOffset)) || + if (failed(aT.getStridesAndOffset(aStrides, aOffset)) || + failed(bT.getStridesAndOffset(bStrides, bOffset)) || aStrides.size() != bStrides.size()) return false; @@ -954,9 +954,9 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, SmallVector originalStrides, candidateStrides; int64_t originalOffset, candidateOffset; if (failed( - getStridesAndOffset(originalType, originalStrides, originalOffset)) || + originalType.getStridesAndOffset(originalStrides, originalOffset)) || failed( - getStridesAndOffset(reducedType, candidateStrides, candidateOffset))) + reducedType.getStridesAndOffset(candidateStrides, candidateOffset))) return failure(); // For memrefs, a dimension is truly dropped if its corresponding stride is @@ -1903,7 +1903,7 @@ LogicalResult ReinterpretCastOp::verify() { // identity layout. int64_t resultOffset; SmallVector resultStrides; - if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset))) return emitError("expected result type to have strided layout but found ") << resultType; @@ -2223,7 +2223,7 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef resultShape, ArrayRef reassociation) { int64_t srcOffset; SmallVector srcStrides; - if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) + if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset))) return failure(); assert(srcStrides.size() == reassociation.size() && "invalid reassociation"); @@ -2420,7 +2420,7 @@ computeCollapsedLayoutMap(MemRefType srcType, int64_t srcOffset; SmallVector srcStrides; auto srcShape = srcType.getShape(); - if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) + if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset))) return failure(); // The result stride of a reassociation group is the stride of the last entry @@ -2706,7 +2706,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType, assert(staticStrides.size() == rank && "staticStrides length mismatch"); // Extract source offset and strides. - auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType); + auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset(); // Compute target offset whose value is: // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. @@ -2912,8 +2912,8 @@ Value SubViewOp::getViewSource() { return getSource(); } static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { int64_t t1Offset, t2Offset; SmallVector t1Strides, t2Strides; - auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset); - auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset); + auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset); + auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset); return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset; } @@ -2928,8 +2928,8 @@ static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, "incorrect number of dropped dims"); int64_t t1Offset, t2Offset; SmallVector t1Strides, t2Strides; - auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset); - auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset); + auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset); + auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset); if (failed(res1) || failed(res2)) return false; for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) { @@ -2980,7 +2980,7 @@ LogicalResult SubViewOp::verify() { << baseType << " and subview memref type " << subViewType; // Verify that the base memref type has a strided layout map. - if (!isStrided(baseType)) + if (!baseType.isStrided()) return emitError("base type ") << baseType << " is not strided"; // Compute the expected result type, assuming that there are no rank @@ -3261,7 +3261,7 @@ struct SubViewReturnTypeCanonicalizer { return nonReducedType; // Take the strides and offset from the non-rank reduced type. - auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType); + auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset(); // Drop dims from shape and strides. SmallVector targetShape; @@ -3341,7 +3341,7 @@ void TransposeOp::getAsmResultNames( static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap) { auto originalSizes = memRefType.getShape(); - auto [originalStrides, offset] = getStridesAndOffset(memRefType); + auto [originalStrides, offset] = memRefType.getStridesAndOffset(); assert(originalStrides.size() == static_cast(memRefType.getRank())); // Compute permuted sizes and strides. @@ -3400,10 +3400,10 @@ LogicalResult TransposeOp::verify() { auto srcType = llvm::cast(getIn().getType()); auto resultType = llvm::cast(getType()); - auto canonicalResultType = canonicalizeStridedLayout( - inferTransposeResultType(srcType, getPermutation())); + auto canonicalResultType = inferTransposeResultType(srcType, getPermutation()) + .canonicalizeStridedLayout(); - if (canonicalizeStridedLayout(resultType) != canonicalResultType) + if (resultType.canonicalizeStridedLayout() != canonicalResultType) return emitOpError("result type ") << resultType << " is not equivalent to the canonical transposed input type " @@ -3483,7 +3483,7 @@ struct ViewOpShapeFolder : public OpRewritePattern { // Get offset from old memref view type 'memRefType'. int64_t oldOffset; SmallVector oldStrides; - if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) + if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset))) return failure(); assert(oldOffset == 0 && "Expected 0 offset"); diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 28f9061d9873b..f58385a7777db 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -632,7 +632,7 @@ void memref::populateMemRefNarrowTypeEmulationConversions( // Currently only handle innermost stride being 1, checking SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(ty, strides, offset))) + if (failed(ty.getStridesAndOffset(strides, offset))) return nullptr; if (!strides.empty() && strides.back() != 1) return nullptr; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index aa008f8407b5d..b69cbabe0dde9 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -68,9 +68,9 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter, auto newExtractStridedMetadata = rewriter.create(origLoc, source); - auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType); + auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); #ifndef NDEBUG - auto [resultStrides, resultOffset] = getStridesAndOffset(subview.getType()); + auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset(); #endif // NDEBUG // Compute the new strides and offset from the base strides and offset: @@ -363,7 +363,7 @@ SmallVector getExpandedStrides(memref::ExpandShapeOp expandShape, // Collect the statically known information about the original stride. Value source = expandShape.getSrc(); auto sourceType = cast(source.getType()); - auto [strides, offset] = getStridesAndOffset(sourceType); + auto [strides, offset] = sourceType.getStridesAndOffset(); OpFoldResult origStride = ShapedType::isDynamic(strides[groupId]) ? origStrides[groupId] @@ -503,7 +503,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, Value source = collapseShape.getSrc(); auto sourceType = cast(source.getType()); - auto [strides, offset] = getStridesAndOffset(sourceType); + auto [strides, offset] = sourceType.getStridesAndOffset(); SmallVector groupStrides; ArrayRef srcShape = sourceType.getShape(); @@ -528,7 +528,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder, // but we still have to make the type system happy. MemRefType collapsedType = collapseShape.getResultType(); auto [collapsedStrides, collapsedOffset] = - getStridesAndOffset(collapsedType); + collapsedType.getStridesAndOffset(); int64_t finalStride = collapsedStrides[groupId]; if (ShapedType::isDynamic(finalStride)) { // Look for a dynamic stride. At this point we don't know which one is @@ -581,7 +581,7 @@ static FailureOr resolveReshapeStridedMetadata( rewriter.create(origLoc, source); // Collect statically known information. - auto [strides, offset] = getStridesAndOffset(sourceType); + auto [strides, offset] = sourceType.getStridesAndOffset(); MemRefType reshapeType = reshape.getResultType(); unsigned reshapeRank = reshapeType.getRank(); @@ -1068,7 +1068,7 @@ class ExtractStridedMetadataOpCastFolder : ofr; }; - auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType); + auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset(); assert(sourceStrides.size() == rank && "unexpected number of strides"); // Register the new offset. diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 450bfa0cec0c7..f93ae0a7a298f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -91,7 +91,7 @@ struct CastOpInterface // Get result offset and strides. int64_t resultOffset; SmallVector resultStrides; - if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset))) return; // Check offset. diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 6de744a7f7524..270b43100a3a7 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -27,7 +27,7 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) { SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(type, strides, offset))) + if (failed(type.getStridesAndOffset(strides, offset))) return false; // MemRef is contiguous if outer dimensions are size-1 and inner diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 47d1b8492e06e..ba86e8d6ceaf9 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -70,9 +70,9 @@ LogicalResult DeviceAsyncCopyOp::verify() { auto srcMemref = llvm::cast(getSrc().getType()); auto dstMemref = llvm::cast(getDst().getType()); - if (!isLastMemrefDimUnitStride(srcMemref)) + if (!srcMemref.isLastDimUnitStride()) return emitError("source memref most minor dim must have unit stride"); - if (!isLastMemrefDimUnitStride(dstMemref)) + if (!dstMemref.isLastDimUnitStride()) return emitError("destination memref most minor dim must have unit stride"); if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref)) return emitError() diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp index f8c699c65fe49..10bc1993ffd96 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp @@ -24,8 +24,8 @@ template static bool isContiguousXferOp(OpTy op) { return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) && op.hasPureBufferSemantics() && - isLastMemrefDimUnitStride( - cast(nvgpu::getMemrefOperand(op).getType())); + cast(nvgpu::getMemrefOperand(op).getType()) + .isLastDimUnitStride(); } /// Return "true" if the given op is a contiguous and suitable diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp index c500815857ca5..39cca7d363e0d 100644 --- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp +++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp @@ -296,7 +296,7 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) { // Check that the last dimension of the read is contiguous. Note that it is // possible to expand support for this by scalarizing all the loads during // conversion. - auto [strides, offset] = mlir::getStridesAndOffset(sourceType); + auto [strides, offset] = sourceType.getStridesAndOffset(); return strides.back() == 1; } @@ -320,6 +320,6 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) { // Check that the last dimension of the target memref is contiguous. Note that // it is possible to expand support for this by scalarizing all the stores // during conversion. - auto [strides, offset] = mlir::getStridesAndOffset(sourceType); + auto [strides, offset] = sourceType.getStridesAndOffset(); return strides.back() == 1; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 29f7e8afe0773..c56dbcca2175d 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -206,7 +206,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { int64_t offset; SmallVector strides; if (!memRefType.hasStaticShape() || - failed(getStridesAndOffset(memRefType, strides, offset))) + failed(memRefType.getStridesAndOffset(strides, offset))) return std::nullopt; // To get the size of the memref object in memory, the total size is the @@ -1225,7 +1225,7 @@ Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(baseType, strides, offset)) || + if (failed(baseType.getStridesAndOffset(strides, offset)) || llvm::is_contained(strides, ShapedType::kDynamic) || ShapedType::isDynamic(offset)) { return nullptr; @@ -1256,7 +1256,7 @@ Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(baseType, strides, offset)) || + if (failed(baseType.getStridesAndOffset(strides, offset)) || llvm::is_contained(strides, ShapedType::kDynamic) || ShapedType::isDynamic(offset)) { return nullptr; diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 1abcacd6d6db3..ed3ba321b37ab 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -186,7 +186,7 @@ struct CollapseShapeOpInterface // the source type. SmallVector strides; int64_t offset; - if (failed(getStridesAndOffset(bufferType, strides, offset))) + if (failed(bufferType.getStridesAndOffset(strides, offset))) return failure(); resultType = MemRefType::get( {}, tensorResultType.getElementType(), diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 696d1e0f9b1e6..d8fc881911bae 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4974,7 +4974,7 @@ static LogicalResult verifyLoadStoreMemRefLayout(Operation *op, (vecTy.getRank() == 0 || vecTy.getNumElements() == 1)) return success(); - if (!isLastMemrefDimUnitStride(memRefTy)) + if (!memRefTy.isLastDimUnitStride()) return op->emitOpError("most minor memref dim must have unit stride"); return success(); } @@ -5789,7 +5789,7 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result, } LogicalResult TypeCastOp::verify() { - MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType()); + MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout(); if (!canonicalType.getLayout().isIdentity()) return emitOpError("expects operand to be a memref with identity layout"); if (!getResultMemRefType().getLayout().isIdentity()) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index f9428a4ce2864..314dc44134e04 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -435,7 +435,7 @@ struct TransferReadToVectorLoadLowering return rewriter.notifyMatchFailure(read, "not a memref source"); // Non-unit strides are handled by VectorToSCF. - if (!isLastMemrefDimUnitStride(memRefType)) + if (!memRefType.isLastDimUnitStride()) return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF"); // If there is broadcasting involved then we first load the unbroadcasted @@ -588,7 +588,7 @@ struct TransferWriteToVectorStoreLowering }); // Non-unit strides are handled by VectorToSCF. - if (!isLastMemrefDimUnitStride(memRefType)) + if (!memRefType.isLastDimUnitStride()) return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { diag << "most minor stride is not 1: " << write; }); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index b0892d16969d2..5871d6dd5b3e6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -267,7 +267,7 @@ static MemRefType dropUnitDims(MemRefType inputType, auto targetShape = getReducedShape(sizes); Type rankReducedType = memref::SubViewOp::inferRankReducedResultType( targetShape, inputType, offsets, sizes, strides); - return canonicalizeStridedLayout(cast(rankReducedType)); + return cast(rankReducedType).canonicalizeStridedLayout(); } /// Creates a rank-reducing memref.subview op that drops unit dims from its @@ -283,8 +283,8 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter, rewriter.getIndexAttr(1)); MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides); - if (canonicalizeStridedLayout(resultType) == - canonicalizeStridedLayout(inputType)) + if (resultType.canonicalizeStridedLayout() == + inputType.canonicalizeStridedLayout()) return input; return rewriter.create(loc, resultType, input, offsets, sizes, strides); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index ee622e886f618..66c23dd6e7495 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -145,8 +145,8 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) { return MemRefType(); int64_t aOffset, bOffset; SmallVector aStrides, bStrides; - if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || - failed(getStridesAndOffset(bT, bStrides, bOffset)) || + if (failed(aT.getStridesAndOffset(aStrides, aOffset)) || + failed(bT.getStridesAndOffset(bStrides, bOffset)) || aStrides.size() != bStrides.size()) return MemRefType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 21ec718efd6a7..84c1deaebcd00 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1243,7 +1243,7 @@ static FailureOr getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { SmallVector srcStrides; int64_t srcOffset; - if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset))) + if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset))) return failure(); auto isUnitDim = [](VectorType type, int dim) { diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index e590d8c43c44b..7b56cd0cf0e91 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -261,7 +261,7 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) { ArrayRef vectorShape = vectorType.getShape(); auto vecRank = vectorType.getRank(); - if (!trailingNDimsContiguous(memrefType, vecRank)) + if (!memrefType.areTrailingDimsContiguous(vecRank)) return false; // Extract the trailing dims and strides of the input memref diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index bd1163bddf7ee..3924d082f0628 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -645,24 +645,74 @@ LogicalResult MemRefType::verify(function_ref emitError, return success(); } -//===----------------------------------------------------------------------===// -// UnrankedMemRefType -//===----------------------------------------------------------------------===// +bool MemRefType::areTrailingDimsContiguous(int64_t n) { + if (!isLastDimUnitStride()) + return false; -unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { - return detail::getMemorySpaceAsInt(getMemorySpace()); + auto memrefShape = getShape().take_back(n); + if (ShapedType::isDynamicShape(memrefShape)) + return false; + + if (getLayout().isIdentity()) + return true; + + int64_t offset; + SmallVector stridesFull; + if (!succeeded(getStridesAndOffset(stridesFull, offset))) + return false; + auto strides = ArrayRef(stridesFull).take_back(n); + + if (strides.empty()) + return true; + + // Check whether strides match "flattened" dims. + SmallVector flattenedDims; + auto dimProduct = 1; + for (auto dim : llvm::reverse(memrefShape.drop_front(1))) { + dimProduct *= dim; + flattenedDims.push_back(dimProduct); + } + + strides = strides.drop_back(1); + return llvm::equal(strides, llvm::reverse(flattenedDims)); } -LogicalResult -UnrankedMemRefType::verify(function_ref emitError, - Type elementType, Attribute memorySpace) { - if (!BaseMemRefType::isValidElementType(elementType)) - return emitError() << "invalid memref element type"; +MemRefType MemRefType::canonicalizeStridedLayout() { + AffineMap m = getLayout().getAffineMap(); - if (!isSupportedMemorySpace(memorySpace)) - return emitError() << "unsupported memory space Attribute"; + // Already in canonical form. + if (m.isIdentity()) + return *this; - return success(); + // Can't reduce to canonical identity form, return in canonical form. + if (m.getNumResults() > 1) + return *this; + + // Corner-case for 0-D affine maps. + if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { + if (auto cst = llvm::dyn_cast(m.getResult(0))) + if (cst.getValue() == 0) + return MemRefType::Builder(*this).setLayout({}); + return *this; + } + + // 0-D corner case for empty shape that still have an affine map. Example: + // `memref (s0)>>`. This is a 1 element memref whose + // offset needs to remain, just return t. + if (getShape().empty()) + return *this; + + // If the canonical strided layout for the sizes of `t` is equal to the + // simplified layout of `t` we can just return an empty layout. Otherwise, + // just simplify the existing layout. + AffineExpr expr = makeCanonicalStridedLayoutExpr(getShape(), getContext()); + auto simplifiedLayoutExpr = + simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); + if (expr != simplifiedLayoutExpr) + return MemRefType::Builder(*this).setLayout( + AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(), + simplifiedLayoutExpr))); + return MemRefType::Builder(*this).setLayout({}); } // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( @@ -783,11 +833,10 @@ static LogicalResult getStridesAndOffset(MemRefType t, return success(); } -LogicalResult mlir::getStridesAndOffset(MemRefType t, - SmallVectorImpl &strides, - int64_t &offset) { +LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl &strides, + int64_t &offset) { // Happy path: the type uses the strided layout directly. - if (auto strided = llvm::dyn_cast(t.getLayout())) { + if (auto strided = llvm::dyn_cast(getLayout())) { llvm::append_range(strides, strided.getStrides()); offset = strided.getOffset(); return success(); @@ -797,14 +846,14 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t, // convertible to affine maps. AffineExpr offsetExpr; SmallVector strideExprs; - if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) + if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr))) return failure(); - if (auto cst = dyn_cast(offsetExpr)) + if (auto cst = llvm::dyn_cast(offsetExpr)) offset = cst.getValue(); else offset = ShapedType::kDynamic; for (auto e : strideExprs) { - if (auto c = dyn_cast(e)) + if (auto c = llvm::dyn_cast(e)) strides.push_back(c.getValue()); else strides.push_back(ShapedType::kDynamic); @@ -812,16 +861,49 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t, return success(); } -std::pair, int64_t> -mlir::getStridesAndOffset(MemRefType t) { +std::pair, int64_t> MemRefType::getStridesAndOffset() { SmallVector strides; int64_t offset; - LogicalResult status = getStridesAndOffset(t, strides, offset); + LogicalResult status = getStridesAndOffset(strides, offset); (void)status; assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset"); return {strides, offset}; } +bool MemRefType::isStrided() { + int64_t offset; + SmallVector strides; + auto res = getStridesAndOffset(strides, offset); + return succeeded(res); +} + +bool MemRefType::isLastDimUnitStride() { + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(strides, offset); + return succeeded(successStrides) && (strides.empty() || strides.back() == 1); +} + +//===----------------------------------------------------------------------===// +// UnrankedMemRefType +//===----------------------------------------------------------------------===// + +unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { + return detail::getMemorySpaceAsInt(getMemorySpace()); +} + +LogicalResult +UnrankedMemRefType::verify(function_ref emitError, + Type elementType, Attribute memorySpace) { + if (!BaseMemRefType::isValidElementType(elementType)) + return emitError() << "invalid memref element type"; + + if (!isSupportedMemorySpace(memorySpace)) + return emitError() << "unsupported memory space Attribute"; + + return success(); +} + //===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===// @@ -849,49 +931,6 @@ size_t TupleType::size() const { return getImpl()->size(); } // Type Utilities //===----------------------------------------------------------------------===// -/// Return a version of `t` with identity layout if it can be determined -/// statically that the layout is the canonical contiguous strided layout. -/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of -/// `t` with simplified layout. -/// If `t` has multiple layout maps or a multi-result layout, just return `t`. -MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { - AffineMap m = t.getLayout().getAffineMap(); - - // Already in canonical form. - if (m.isIdentity()) - return t; - - // Can't reduce to canonical identity form, return in canonical form. - if (m.getNumResults() > 1) - return t; - - // Corner-case for 0-D affine maps. - if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { - if (auto cst = dyn_cast(m.getResult(0))) - if (cst.getValue() == 0) - return MemRefType::Builder(t).setLayout({}); - return t; - } - - // 0-D corner case for empty shape that still have an affine map. Example: - // `memref (s0)>>`. This is a 1 element memref whose - // offset needs to remain, just return t. - if (t.getShape().empty()) - return t; - - // If the canonical strided layout for the sizes of `t` is equal to the - // simplified layout of `t` we can just return an empty layout. Otherwise, - // just simplify the existing layout. - AffineExpr expr = - makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); - auto simplifiedLayoutExpr = - simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); - if (expr != simplifiedLayoutExpr) - return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( - m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); - return MemRefType::Builder(t).setLayout({}); -} - AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, ArrayRef exprs, MLIRContext *context) { @@ -932,49 +971,3 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef sizes, exprs.push_back(getAffineDimExpr(dim, context)); return makeCanonicalStridedLayoutExpr(sizes, exprs, context); } - -bool mlir::isStrided(MemRefType t) { - int64_t offset; - SmallVector strides; - auto res = getStridesAndOffset(t, strides, offset); - return succeeded(res); -} - -bool mlir::isLastMemrefDimUnitStride(MemRefType type) { - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - return succeeded(successStrides) && (strides.empty() || strides.back() == 1); -} - -bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) { - if (!isLastMemrefDimUnitStride(type)) - return false; - - auto memrefShape = type.getShape().take_back(n); - if (ShapedType::isDynamicShape(memrefShape)) - return false; - - if (type.getLayout().isIdentity()) - return true; - - int64_t offset; - SmallVector stridesFull; - if (!succeeded(getStridesAndOffset(type, stridesFull, offset))) - return false; - auto strides = ArrayRef(stridesFull).take_back(n); - - if (strides.empty()) - return true; - - // Check whether strides match "flattened" dims. - SmallVector flattenedDims; - auto dimProduct = 1; - for (auto dim : llvm::reverse(memrefShape.drop_front(1))) { - dimProduct *= dim; - flattenedDims.push_back(dimProduct); - } - - strides = strides.drop_back(1); - return llvm::equal(strides, llvm::reverse(flattenedDims)); -} diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp index 968e10b8d0cab..f17f5db2fa22f 100644 --- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp +++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp @@ -35,7 +35,7 @@ void TestMemRefStrideCalculation::runOnOperation() { auto memrefType = cast(allocOp.getResult().getType()); int64_t offset; SmallVector strides; - if (failed(getStridesAndOffset(memrefType, strides, offset))) { + if (failed(memrefType.getStridesAndOffset(strides, offset))) { llvm::outs() << "MemRefType " << memrefType << " cannot be converted to " << "strided form\n"; return;