
Commit 6aaa8f2

[mlir][IR][NFC] Move free-standing functions to MemRefType (#123465)
Turn free-standing `MemRefType`-related helper functions in `BuiltinTypes.h` into member functions.
Parent: 79231a8

38 files changed: +228 −240 lines
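Every call site below follows the same mechanical pattern: a free function in the `mlir` namespace that took the `MemRefType` as its first argument becomes a member call on the type itself. A minimal sketch of the before/after shape of the change (the function and variable names here are illustrative, not from the patch):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical call site; `memrefType` stands in for any strided MemRefType.
LogicalResult lowerBuffer(MemRefType memrefType) {
  SmallVector<int64_t> strides;
  int64_t offset;

  // Before: free-standing helper taking the type as an argument.
  //   if (failed(getStridesAndOffset(memrefType, strides, offset)))
  //     return failure();

  // After: the same query as a member function on MemRefType.
  if (failed(memrefType.getStridesAndOffset(strides, offset)))
    return failure();
  return success();
}

The renames follow the same scheme throughout: `isStrided(t)` becomes `t.isStrided()`, `isLastMemrefDimUnitStride(t)` becomes `t.isLastDimUnitStride()`, and `trailingNDimsContiguous(t, n)` becomes `t.areTrailingDimsContiguous(n)`.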

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
     assert(memrefType && "Incorrect use of getStaticStrides");
-    auto [strides, offset] = getStridesAndOffset(memrefType);
+    auto [strides, offset] = memrefType.getStridesAndOffset();
     // reuse the storage of ConstStridesAttr since strides from
     // memref is not persistant
     setConstStrides(strides);

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 0 additions & 45 deletions
@@ -409,33 +409,6 @@ inline bool TensorType::classof(Type type) {
 // Type Utilities
 //===----------------------------------------------------------------------===//

-/// Returns the strides of the MemRef if the layout map is in strided form.
-/// MemRefs with a layout map in strided form include:
-///   1. empty or identity layout map, in which case the stride information is
-///      the canonical form computed from sizes;
-///   2. a StridedLayoutAttr layout;
-///   3. any other layout that be converted into a single affine map layout of
-///      the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
-///      symbols.
-///
-/// A stride specification is a list of integer values that are either static
-/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
-/// the distance in the number of elements between successive entries along a
-/// particular dimension.
-LogicalResult getStridesAndOffset(MemRefType t,
-                                  SmallVectorImpl<int64_t> &strides,
-                                  int64_t &offset);
-
-/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
-/// int64_t) that will assert if the logical result is not succeeded.
-std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
-
-/// Return a version of `t` with identity layout if it can be determined
-/// statically that the layout is the canonical contiguous strided layout.
-/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
-/// `t` with simplified layout.
-MemRefType canonicalizeStridedLayout(MemRefType t);
-
 /// Given MemRef `sizes` that are either static or dynamic, returns the
 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
 /// once a dynamic dimension is encountered, all canonical strides become
@@ -458,24 +431,6 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                           MLIRContext *context);
-
-/// Return "true" if the layout for `t` is compatible with strided semantics.
-bool isStrided(MemRefType t);
-
-/// Return "true" if the last dimension of the given type has a static unit
-/// stride. Also return "true" for types with no strides.
-bool isLastMemrefDimUnitStride(MemRefType type);
-
-/// Return "true" if the last N dimensions of the given type are contiguous.
-///
-/// Examples:
-///   - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
-///     considering both _all_ and _only_ the trailing 3 dims,
-///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
-///     considering the trailing 3 dims.
-///
-bool trailingNDimsContiguous(MemRefType type, int64_t n);
-
 } // namespace mlir

 #endif // MLIR_IR_BUILTINTYPES_H

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 42 additions & 0 deletions
@@ -808,10 +808,52 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;

+    /// Return "true" if the last N dimensions are contiguous.
+    ///
+    /// Examples:
+    ///   - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
+    ///     considering both _all_ and _only_ the trailing 3 dims,
+    ///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
+    ///     considering the trailing 3 dims.
+    ///
+    bool areTrailingDimsContiguous(int64_t n);
+
+    /// Return a version of this type with identity layout if it can be
+    /// determined statically that the layout is the canonical contiguous
+    /// strided layout. Otherwise pass the layout into `simplifyAffineMap`
+    /// and return a copy of this type with simplified layout.
+    MemRefType canonicalizeStridedLayout();
+
     /// [deprecated] Returns the memory space in old raw integer representation.
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;

+    /// Returns the strides of the MemRef if the layout map is in strided form.
+    /// MemRefs with a layout map in strided form include:
+    ///   1. empty or identity layout map, in which case the stride information
+    ///      is the canonical form computed from sizes;
+    ///   2. a StridedLayoutAttr layout;
+    ///   3. any other layout that be converted into a single affine map layout
+    ///      of the form `K + k0 * d0 + ... kn * dn`, where K and ki's are
+    ///      constants or symbols.
+    ///
+    /// A stride specification is a list of integer values that are either
+    /// static or dynamic (encoded with ShapedType::kDynamic). Strides encode
+    /// the distance in the number of elements between successive entries along
+    /// a particular dimension.
+    LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
+                                      int64_t &offset);
+
+    /// Wrapper around getStridesAndOffset(SmallVectorImpl<int64_t>, int64_t)
+    /// that will assert if the logical result is not succeeded.
+    std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset();
+
+    /// Return "true" if the layout is compatible with strided semantics.
+    bool isStrided();
+
+    /// Return "true" if the last dimension has a static unit stride. Also
+    /// return "true" for types with no strides.
+    bool isLastDimUnitStride();
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
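The doc comments above fully determine the behavior for an identity layout: canonical strides are computed from the sizes. A small sketch of what the new member API returns in that case, assuming the standard `MemRefType::get` builder (the surrounding setup is illustrative, not part of the patch):

#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// memref<5x4x3x2xi8> with identity layout: the canonical strides computed
// from the sizes are [24, 6, 2, 1] with offset 0, matching the contiguous
// example in the documentation above.
void stridedQueries() {
  MLIRContext ctx;
  MemRefType t = MemRefType::get({5, 4, 3, 2}, IntegerType::get(&ctx, 8));

  auto [strides, offset] = t.getStridesAndOffset();
  assert(offset == 0);
  assert((strides == SmallVector<int64_t>{24, 6, 2, 1}));

  assert(t.isStrided());                  // identity layout is strided
  assert(t.isLastDimUnitStride());        // innermost stride is 1
  assert(t.areTrailingDimsContiguous(3)); // per the doc example above
}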

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 1 addition & 1 deletion
@@ -820,7 +820,7 @@ class StaticShapeMemRefOf<list<Type> allowedTypes> :
 def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;

 // For a MemRefType, verify that it has strides.
-def HasStridesPred : CPred<[{ isStrided(::llvm::cast<::mlir::MemRefType>($_self)) }]>;
+def HasStridesPred : CPred<[{ ::llvm::cast<::mlir::MemRefType>($_self).isStrided() }]>;

 class StridedMemRefOf<list<Type> allowedTypes> :
     ConfinedType<MemRefOf<allowedTypes>, [HasStridesPred],

mlir/lib/CAPI/IR/BuiltinTypes.cpp

Lines changed: 1 addition & 1 deletion
@@ -524,7 +524,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
                                                     int64_t *offset) {
   MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
   SmallVector<int64_t> strides_;
-  if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
+  if (failed(memrefType.getStridesAndOffset(strides_, *offset)))
     return mlirLogicalResultFailure();

   (void)std::copy(strides_.begin(), strides_.end(), strides);

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     // Construct buffer descriptor from memref, attributes
     int64_t offset = 0;
     SmallVector<int64_t, 5> strides;
-    if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    if (failed(memrefType.getStridesAndOffset(strides, offset)))
       return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

     MemRefDescriptor memrefDescriptor(memref);

mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
   assert(type.hasStaticShape() && "unexpected dynamic shape");

   // Extract all strides and offsets and verify they are static.
-  auto [strides, offset] = getStridesAndOffset(type);
+  auto [strides, offset] = type.getStridesAndOffset();
   assert(!ShapedType::isDynamic(offset) && "expected static offset");
   assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
          "expected static strides");
@@ -193,7 +193,7 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
                                   MemRefType type) {
   // When we convert to LLVM, the input memref must have been normalized
   // beforehand. Hence, this call is guaranteed to work.
-  auto [strides, offsetCst] = getStridesAndOffset(type);
+  auto [strides, offsetCst] = type.getStridesAndOffset();

   Value ptr = alignedPtr(builder, loc);
   // For zero offsets, we already have the base pointer.

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
     ConversionPatternRewriter &rewriter) const {

-  auto [strides, offset] = getStridesAndOffset(type);
+  auto [strides, offset] = type.getStridesAndOffset();

   MemRefDescriptor memRefDescriptor(memRefDesc);
   // Use a canonical representation of the start address so that later

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 2 additions & 2 deletions
@@ -485,7 +485,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
 SmallVector<Type, 5>
 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
                                              bool unpackAggregates) const {
-  if (!isStrided(type)) {
+  if (!type.isStrided()) {
     emitError(
         UnknownLoc::get(type.getContext()),
         "conversion to strided form failed either due to non-strided layout "
@@ -603,7 +603,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {

   int64_t offset = 0;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
+  if (failed(memrefTy.getStridesAndOffset(strides, offset)))
     return false;

   for (int64_t stride : strides)

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 3 additions & 3 deletions
@@ -1136,7 +1136,7 @@ struct MemRefReshapeOpLowering
     // Extract the offset and strides from the type.
     int64_t offset;
     SmallVector<int64_t> strides;
-    if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
+    if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
       return rewriter.notifyMatchFailure(
           reshapeOp, "failed to get stride and offset exprs");

@@ -1451,7 +1451,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {

     int64_t offset;
     SmallVector<int64_t, 4> strides;
-    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
     if (failed(successStrides))
       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
     assert(offset == 0 && "expected offset to be 0");
@@ -1560,7 +1560,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
     auto memRefType = atomicOp.getMemRefType();
     SmallVector<int64_t> strides;
     int64_t offset;
-    if (failed(getStridesAndOffset(memRefType, strides, offset)))
+    if (failed(memRefType.getStridesAndOffset(strides, offset)))
       return failure();
     auto dataPtr =
         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
     return 0;
   int64_t offset = 0;
   SmallVector<int64_t, 2> strides;
-  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
+  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
       strides.back() != 1)
     return std::nullopt;
   int64_t stride = strides[strides.size() - 2];

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 2 deletions
@@ -91,7 +91,7 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
 // Check if the last stride is non-unit and has a valid memory space.
 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                            const LLVMTypeConverter &converter) {
-  if (!isLastMemrefDimUnitStride(memRefType))
+  if (!memRefType.isLastDimUnitStride())
     return failure();
   if (failed(converter.getMemRefAddressSpace(memRefType)))
     return failure();
@@ -1374,7 +1374,7 @@ static std::optional<SmallVector<int64_t, 4>>
 computeContiguousStrides(MemRefType memRefType) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(memRefType, strides, offset)))
+  if (failed(memRefType.getStridesAndOffset(strides, offset)))
     return std::nullopt;
   if (!strides.empty() && strides.back() != 1)
     return std::nullopt;

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 1 addition & 1 deletion
@@ -1650,7 +1650,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
       return failure();
     if (xferOp.getVectorType().getRank() != 1)
       return failure();
-    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
+    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
       return failure(); // Handled by ConvertVectorToLLVM

     // Loop bounds, step, state...

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 2 additions & 3 deletions
@@ -76,8 +76,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   // Validate further transfer op semantics.
   SmallVector<int64_t> strides;
   int64_t offset;
-  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
-      strides.back() != 1)
+  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
     return rewriter.notifyMatchFailure(
         xferOp, "Buffer must be contiguous in the innermost dimension");

@@ -105,7 +104,7 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                    Operation::operand_range offsets) {
   MemRefType srcTy = src.getType();
-  auto [strides, offset] = getStridesAndOffset(srcTy);
+  auto [strides, offset] = srcTy.getStridesAndOffset();

   xegpu::CreateNdDescOp ndDesc;
   if (srcTy.hasStaticShape()) {

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ static bool staticallyOutOfBounds(OpType op) {
     return false;
   int64_t offset;
   SmallVector<int64_t> strides;
-  if (failed(getStridesAndOffset(bufferType, strides, offset)))
+  if (failed(bufferType.getStridesAndOffset(strides, offset)))
     return false;
   int64_t result = offset + op.getIndexOffset().value_or(0);
   if (op.getSgprOffset()) {

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 1 addition & 2 deletions
@@ -53,8 +53,7 @@ FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
   unsigned bytes = width >> 3;
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(mType, strides, offset)) ||
-      strides.back() != 1)
+  if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
     return failure();
   if (strides[preLast] == ShapedType::kDynamic) {
     // Dynamic stride needs code to compute the stride at runtime.

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
   auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
     int64_t sourceOffset, targetOffset;
     SmallVector<int64_t, 4> sourceStrides, targetStrides;
-    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
-        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
+        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
       return false;
     auto dynamicToStatic = [](int64_t a, int64_t b) {
      return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);

mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(type, strides, offset)))
+  if (failed(type.getStridesAndOffset(strides, offset)))
     return false;
   if (!llvm::all_of(strides, ShapedType::isDynamic))
     return false;

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -1903,7 +1903,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
   auto operand = resMatrixType.getOperand();
   auto srcMemrefType = llvm::cast<MemRefType>(srcType);

-  if (!isLastMemrefDimUnitStride(srcMemrefType))
+  if (!srcMemrefType.isLastDimUnitStride())
     return emitError(
         "expected source memref most minor dim must have unit stride");

@@ -1923,7 +1923,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
   auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
   auto dstMemrefType = llvm::cast<MemRefType>(dstType);

-  if (!isLastMemrefDimUnitStride(dstMemrefType))
+  if (!dstMemrefType.isLastDimUnitStride())
     return emitError(
         "expected destination memref most minor dim must have unit stride");


mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
   }

-  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

   auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
     return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
