
[mlir][IR][NFC] Move free-standing functions to MemRefType #123465

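This NFC change moves the free-standing MemRef stride/layout helpers (`getStridesAndOffset`, `isStrided`, `isLastMemrefDimUnitStride`, `trailingNDimsContiguous`, `canonicalizeStridedLayout`) out of `BuiltinTypes.h` and onto `MemRefType` as member functions, updating all call sites. A minimal sketch of the call-site migration follows; it is illustrative only and not taken verbatim from the patch (the wrapper names `queryOld`/`queryNew` are made up):

```cpp
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Before this change: free functions declared in mlir/include/mlir/IR/BuiltinTypes.h.
static LogicalResult queryOld(MemRefType memrefType) {
  SmallVector<int64_t> strides;
  int64_t offset;
  return getStridesAndOffset(memrefType, strides, offset);
}

// After this change: the same helper exposed as a MemRefType member (NFC).
static LogicalResult queryNew(MemRefType memrefType) {
  SmallVector<int64_t> strides;
  int64_t offset;
  return memrefType.getStridesAndOffset(strides, offset);
}
```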

Merged · 1 commit · Jan 21, 2025
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -198,7 +198,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface

auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
assert(memrefType && "Incorrect use of getStaticStrides");
auto [strides, offset] = getStridesAndOffset(memrefType);
auto [strides, offset] = memrefType.getStridesAndOffset();
// reuse the storage of ConstStridesAttr since strides from
// memref are not persistent
setConstStrides(strides);
45 changes: 0 additions & 45 deletions mlir/include/mlir/IR/BuiltinTypes.h
@@ -409,33 +409,6 @@ inline bool TensorType::classof(Type type) {
// Type Utilities
//===----------------------------------------------------------------------===//

/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with a layout map in strided form include:
/// 1. empty or identity layout map, in which case the stride information is
/// the canonical form computed from sizes;
/// 2. a StridedLayoutAttr layout;
/// 3. any other layout that can be converted into a single affine map layout of
/// the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
/// symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along a
/// particular dimension.
LogicalResult getStridesAndOffset(MemRefType t,
SmallVectorImpl<int64_t> &strides,
int64_t &offset);

/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
/// int64_t) that asserts if the logical result is not a success.
std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
MemRefType canonicalizeStridedLayout(MemRefType t);

/// Given MemRef `sizes` that are either static or dynamic, returns the
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
/// once a dynamic dimension is encountered, all canonical strides become
@@ -458,24 +431,6 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
MLIRContext *context);

/// Return "true" if the layout for `t` is compatible with strided semantics.
bool isStrided(MemRefType t);

/// Return "true" if the last dimension of the given type has a static unit
/// stride. Also return "true" for types with no strides.
bool isLastMemrefDimUnitStride(MemRefType type);

/// Return "true" if the last N dimensions of the given type are contiguous.
///
/// Examples:
/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
/// considering both _all_ and _only_ the trailing 3 dims,
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
/// considering the trailing 3 dims.
///
bool trailingNDimsContiguous(MemRefType type, int64_t n);

} // namespace mlir

#endif // MLIR_IR_BUILTINTYPES_H
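The contiguity examples in the removed comment above carry over to the new member `areTrailingDimsContiguous`. A small sketch of how they might be checked programmatically, assuming an initialized MLIRContext (the helper name `checkContiguity` is invented for illustration):

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Builds the two strided memrefs from the doc-comment example and queries
// trailing-dimension contiguity through the new MemRefType member.
void checkContiguity(MLIRContext *ctx) {
  auto i8 = IntegerType::get(ctx, 8);
  auto layoutA = StridedLayoutAttr::get(ctx, /*offset=*/0,
                                        /*strides=*/{24, 6, 2, 1});
  auto layoutB = StridedLayoutAttr::get(ctx, /*offset=*/0,
                                        /*strides=*/{48, 6, 2, 1});
  auto a = MemRefType::get({5, 4, 3, 2}, i8, layoutA);
  auto b = MemRefType::get({5, 4, 3, 2}, i8, layoutB);

  bool aAll = a.areTrailingDimsContiguous(4);       // true: fully contiguous
  bool bAll = b.areTrailingDimsContiguous(4);       // false: outer stride is 48, not 24
  bool bTrailing3 = b.areTrailingDimsContiguous(3); // true: inner 3 dims are contiguous
  (void)aAll; (void)bAll; (void)bTrailing3;
}
```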
42 changes: 42 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
@@ -808,10 +808,52 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
/// Arguments that are passed into the builder must outlive the builder.
class Builder;

/// Return "true" if the last N dimensions are contiguous.
///
/// Examples:
/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
/// considering both _all_ and _only_ the trailing 3 dims,
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
/// considering the trailing 3 dims.
///
bool areTrailingDimsContiguous(int64_t n);

/// Return a version of this type with identity layout if it can be
/// determined statically that the layout is the canonical contiguous
/// strided layout. Otherwise pass the layout into `simplifyAffineMap`
/// and return a copy of this type with simplified layout.
MemRefType canonicalizeStridedLayout();

/// [deprecated] Returns the memory space in old raw integer representation.
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;

/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with a layout map in strided form include:
/// 1. empty or identity layout map, in which case the stride information
/// is the canonical form computed from sizes;
/// 2. a StridedLayoutAttr layout;
/// 3. any other layout that can be converted into a single affine map layout
/// of the form `K + k0 * d0 + ... kn * dn`, where K and ki's are
/// constants or symbols.
///
/// A stride specification is a list of integer values that are either
/// static or dynamic (encoded with ShapedType::kDynamic). Strides encode
/// the distance in the number of elements between successive entries along
/// a particular dimension.
LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
int64_t &offset);

/// Wrapper around getStridesAndOffset(SmallVectorImpl<int64_t>, int64_t)
/// that asserts if the logical result is not a success.
std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset();

/// Return "true" if the layout is compatible with strided semantics.
bool isStrided();

/// Return "true" if the last dimension has a static unit stride. Also
/// return "true" for types with no strides.
bool isLastDimUnitStride();
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
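A short usage sketch of the members declared above, applied to a memref<4x8xf32> with the default identity layout (the function name `inspectLayout` is hypothetical, not part of the patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Queries layout information through the new MemRefType members.
void inspectLayout(MLIRContext *ctx) {
  auto type = MemRefType::get({4, 8}, Float32Type::get(ctx));

  // Asserting wrapper: safe here because the identity layout is strided.
  auto [strides, offset] = type.getStridesAndOffset(); // strides = {8, 1}, offset = 0

  // Non-asserting form for layouts that may not be strided.
  SmallVector<int64_t> s;
  int64_t off;
  LogicalResult ok = type.getStridesAndOffset(s, off);

  bool strided = type.isStrided();            // true
  bool unitLast = type.isLastDimUnitStride(); // true: innermost stride is 1
  (void)strides; (void)offset; (void)ok; (void)strided; (void)unitLast;
}
```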
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -820,7 +820,7 @@ class StaticShapeMemRefOf<list<Type> allowedTypes> :
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;

// For a MemRefType, verify that it has strides.
def HasStridesPred : CPred<[{ isStrided(::llvm::cast<::mlir::MemRefType>($_self)) }]>;
def HasStridesPred : CPred<[{ ::llvm::cast<::mlir::MemRefType>($_self).isStrided() }]>;

class StridedMemRefOf<list<Type> allowedTypes> :
ConfinedType<MemRefOf<allowedTypes>, [HasStridesPred],
2 changes: 1 addition & 1 deletion mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -524,7 +524,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
int64_t *offset) {
MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
SmallVector<int64_t> strides_;
if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
if (failed(memrefType.getStridesAndOffset(strides_, *offset)))
return mlirLogicalResultFailure();

(void)std::copy(strides_.begin(), strides_.end(), strides);
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -192,7 +192,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
// Construct buffer descriptor from memref, attributes
int64_t offset = 0;
SmallVector<int64_t, 5> strides;
if (failed(getStridesAndOffset(memrefType, strides, offset)))
if (failed(memrefType.getStridesAndOffset(strides, offset)))
return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

MemRefDescriptor memrefDescriptor(memref);
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -52,7 +52,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
assert(type.hasStaticShape() && "unexpected dynamic shape");

// Extract all strides and offsets and verify they are static.
auto [strides, offset] = getStridesAndOffset(type);
auto [strides, offset] = type.getStridesAndOffset();
assert(!ShapedType::isDynamic(offset) && "expected static offset");
assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
"expected static strides");
@@ -193,7 +193,7 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
MemRefType type) {
// When we convert to LLVM, the input memref must have been normalized
// beforehand. Hence, this call is guaranteed to work.
auto [strides, offsetCst] = getStridesAndOffset(type);
auto [strides, offsetCst] = type.getStridesAndOffset();

Value ptr = alignedPtr(builder, loc);
// For zero offsets, we already have the base pointer.
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -62,7 +62,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
ConversionPatternRewriter &rewriter) const {

auto [strides, offset] = getStridesAndOffset(type);
auto [strides, offset] = type.getStridesAndOffset();

MemRefDescriptor memRefDescriptor(memRefDesc);
// Use a canonical representation of the start address so that later
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -485,7 +485,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
SmallVector<Type, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
bool unpackAggregates) const {
if (!isStrided(type)) {
if (!type.isStrided()) {
emitError(
UnknownLoc::get(type.getContext()),
"conversion to strided form failed either due to non-strided layout "
@@ -603,7 +603,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {

int64_t offset = 0;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memrefTy, strides, offset)))
if (failed(memrefTy.getStridesAndOffset(strides, offset)))
return false;

for (int64_t stride : strides)
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1136,7 +1136,7 @@ struct MemRefReshapeOpLowering
// Extract the offset and strides from the type.
int64_t offset;
SmallVector<int64_t> strides;
if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
return rewriter.notifyMatchFailure(
reshapeOp, "failed to get stride and offset exprs");

@@ -1451,7 +1451,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {

int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
if (failed(successStrides))
return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
assert(offset == 0 && "expected offset to be 0");
@@ -1560,7 +1560,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
auto memRefType = atomicOp.getMemRefType();
SmallVector<int64_t> strides;
int64_t offset;
if (failed(getStridesAndOffset(memRefType, strides, offset)))
if (failed(memRefType.getStridesAndOffset(strides, offset)))
return failure();
auto dataPtr =
getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -132,7 +132,7 @@ static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
return 0;
int64_t offset = 0;
SmallVector<int64_t, 2> strides;
if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return std::nullopt;
int64_t stride = strides[strides.size() - 2];
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,7 +91,7 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
// Check if the last stride is non-unit and has a valid memory space.
static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
const LLVMTypeConverter &converter) {
if (!isLastMemrefDimUnitStride(memRefType))
if (!memRefType.isLastDimUnitStride())
return failure();
if (failed(converter.getMemRefAddressSpace(memRefType)))
return failure();
@@ -1374,7 +1374,7 @@ static std::optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memRefType, strides, offset)))
if (failed(memRefType.getStridesAndOffset(strides, offset)))
return std::nullopt;
if (!strides.empty() && strides.back() != 1)
return std::nullopt;
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1650,7 +1650,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
return failure();
if (xferOp.getVectorType().getRank() != 1)
return failure();
if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
return failure(); // Handled by ConvertVectorToLLVM

// Loop bounds, step, state...
5 changes: 2 additions & 3 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -76,8 +76,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
// Validate further transfer op semantics.
SmallVector<int64_t> strides;
int64_t offset;
if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
strides.back() != 1)
if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
return rewriter.notifyMatchFailure(
xferOp, "Buffer must be contiguous in the innermost dimension");

@@ -105,7 +104,7 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
Operation::operand_range offsets) {
MemRefType srcTy = src.getType();
auto [strides, offset] = getStridesAndOffset(srcTy);
auto [strides, offset] = srcTy.getStridesAndOffset();

xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -129,7 +129,7 @@ static bool staticallyOutOfBounds(OpType op) {
return false;
int64_t offset;
SmallVector<int64_t> strides;
if (failed(getStridesAndOffset(bufferType, strides, offset)))
if (failed(bufferType.getStridesAndOffset(strides, offset)))
return false;
int64_t result = offset + op.getIndexOffset().value_or(0);
if (op.getSgprOffset()) {
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -53,8 +53,7 @@ FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
unsigned bytes = width >> 3;
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(mType, strides, offset)) ||
strides.back() != 1)
if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
return failure();
if (strides[preLast] == ShapedType::kDynamic) {
// Dynamic stride needs code to compute the stride at runtime.
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -42,8 +42,8 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
int64_t sourceOffset, targetOffset;
SmallVector<int64_t, 4> sourceStrides, targetStrides;
if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
failed(getStridesAndOffset(target, targetStrides, targetOffset)))
if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
failed(target.getStridesAndOffset(targetStrides, targetOffset)))
return false;
auto dynamicToStatic = [](int64_t a, int64_t b) {
return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
@@ -29,7 +29,7 @@ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
static bool hasFullyDynamicLayoutMap(MemRefType type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(type, strides, offset)))
if (failed(type.getStridesAndOffset(strides, offset)))
return false;
if (!llvm::all_of(strides, ShapedType::isDynamic))
return false;
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1903,7 +1903,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
auto operand = resMatrixType.getOperand();
auto srcMemrefType = llvm::cast<MemRefType>(srcType);

if (!isLastMemrefDimUnitStride(srcMemrefType))
if (!srcMemrefType.isLastDimUnitStride())
return emitError(
"expected source memref most minor dim must have unit stride");

Expand All @@ -1923,7 +1923,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
auto dstMemrefType = llvm::cast<MemRefType>(dstType);

if (!isLastMemrefDimUnitStride(dstMemrefType))
if (!dstMemrefType.isLastDimUnitStride())
return emitError(
"expected destination memref most minor dim must have unit stride");

2 changes: 1 addition & 1 deletion mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
@@ -67,7 +67,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
}

auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)