[Mlir-commits] [mlir] [mlir][IR][NFC] Move free-standing functions to `MemRefType` (PR #123465)

Matthias Springer llvmlistbot at llvm.org
Sat Jan 18 07:56:19 PST 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/123465

Turn the free-standing `MemRefType`-related helper functions declared in `BuiltinTypes.h` into member functions of `MemRefType`: `getStridesAndOffset`, `canonicalizeStridedLayout`, and `isStrided` keep their names, while `isLastMemrefDimUnitStride` becomes `isLastDimUnitStride` and `trailingNDimsContiguous` becomes `areTrailingDimsContiguous`. All call sites are updated accordingly.
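
For downstream callers the update is mechanical. Below is a minimal sketch of the old and new spellings, based on the call sites touched in this patch; the `inspectLayout` wrapper is purely illustrative and not part of the change.

#include "mlir/IR/BuiltinTypes.h"

// Hypothetical helper showing the before/after spelling of the moved APIs.
static void inspectLayout(mlir::MemRefType memrefType) {
  // Before this patch (free-standing helpers from BuiltinTypes.h):
  //   auto [strides, offset] = getStridesAndOffset(memrefType);
  //   bool strided = isStrided(memrefType);
  //   bool unitStride = isLastMemrefDimUnitStride(memrefType);

  // After this patch (member functions on MemRefType):
  auto [strides, offset] = memrefType.getStridesAndOffset();
  bool strided = memrefType.isStrided();
  bool unitStride = memrefType.isLastDimUnitStride();

  // The renamed helpers follow the same pattern:
  //   canonicalizeStridedLayout(t)    -> t.canonicalizeStridedLayout()
  //   trailingNDimsContiguous(t, n)   -> t.areTrailingDimsContiguous(n)
  (void)strides;
  (void)offset;
  (void)strided;
  (void)unitStride;
}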


From b68d260a79d46efd4dddc37bece9a95d8f7d51b1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 18 Jan 2025 12:47:10 +0100
Subject: [PATCH] [mlir][IR][NFC] Move free-standing functions to `MemRefType`

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   2 +-
 mlir/include/mlir/IR/BuiltinTypes.h           |  45 ----
 mlir/include/mlir/IR/BuiltinTypes.td          |  42 ++++
 mlir/include/mlir/IR/CommonTypeConstraints.td |   2 +-
 mlir/lib/CAPI/IR/BuiltinTypes.cpp             |   2 +-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |   2 +-
 .../Conversion/LLVMCommon/MemRefBuilder.cpp   |   4 +-
 mlir/lib/Conversion/LLVMCommon/Pattern.cpp    |   2 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   |   4 +-
 .../Conversion/MemRefToLLVM/MemRefToLLVM.cpp  |   6 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |   2 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |   4 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    |   2 +-
 .../VectorToXeGPU/VectorToXeGPU.cpp           |   5 +-
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |   2 +-
 .../AMX/Transforms/LegalizeForLLVMExport.cpp  |   3 +-
 .../Bufferization/IR/BufferizationOps.cpp     |   4 +-
 .../Transforms/BufferResultsToOutParams.cpp   |   2 +-
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        |   4 +-
 .../GPU/Transforms/DecomposeMemRefs.cpp       |   2 +-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  46 ++--
 .../MemRef/Transforms/EmulateNarrowType.cpp   |   2 +-
 .../Transforms/ExpandStridedMetadata.cpp      |  14 +-
 .../Transforms/RuntimeOpVerification.cpp      |   2 +-
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp |   2 +-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    |   4 +-
 .../NVGPU/Transforms/CreateAsyncGroups.cpp    |   4 +-
 mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp     |   4 +-
 .../SPIRV/Transforms/SPIRVConversion.cpp      |   6 +-
 .../BufferizableOpInterfaceImpl.cpp           |   2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      |   4 +-
 .../Vector/Transforms/LowerVectorTransfer.cpp |   4 +-
 .../Transforms/VectorTransferOpTransforms.cpp |   6 +-
 .../VectorTransferSplitRewritePatterns.cpp    |   4 +-
 .../Vector/Transforms/VectorTransforms.cpp    |   2 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |   2 +-
 mlir/lib/IR/BuiltinTypes.cpp                  | 217 +++++++++---------
 .../Analysis/TestMemRefStrideCalculation.cpp  |   2 +-
 38 files changed, 228 insertions(+), 240 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 5910aa3f7f2dae..f5cf3dad75d9c2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -198,7 +198,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
 
       auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticStrides");
-      auto [strides, offset] = getStridesAndOffset(memrefType);
+      auto [strides, offset] = memrefType.getStridesAndOffset();
       // reuse the storage of ConstStridesAttr since strides from
       // memref is not persistant
       setConstStrides(strides);
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 19c5361124aacb..df1e02732617d2 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -409,33 +409,6 @@ inline bool TensorType::classof(Type type) {
 // Type Utilities
 //===----------------------------------------------------------------------===//
 
-/// Returns the strides of the MemRef if the layout map is in strided form.
-/// MemRefs with a layout map in strided form include:
-///   1. empty or identity layout map, in which case the stride information is
-///      the canonical form computed from sizes;
-///   2. a StridedLayoutAttr layout;
-///   3. any other layout that be converted into a single affine map layout of
-///      the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
-///      symbols.
-///
-/// A stride specification is a list of integer values that are either static
-/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
-/// the distance in the number of elements between successive entries along a
-/// particular dimension.
-LogicalResult getStridesAndOffset(MemRefType t,
-                                  SmallVectorImpl<int64_t> &strides,
-                                  int64_t &offset);
-
-/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
-/// int64_t) that will assert if the logical result is not succeeded.
-std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
-
-/// Return a version of `t` with identity layout if it can be determined
-/// statically that the layout is the canonical contiguous strided layout.
-/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
-/// `t` with simplified layout.
-MemRefType canonicalizeStridedLayout(MemRefType t);
-
 /// Given MemRef `sizes` that are either static or dynamic, returns the
 /// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
 /// once a dynamic dimension is encountered, all canonical strides become
@@ -458,24 +431,6 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
 /// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
 AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                           MLIRContext *context);
-
-/// Return "true" if the layout for `t` is compatible with strided semantics.
-bool isStrided(MemRefType t);
-
-/// Return "true" if the last dimension of the given type has a static unit
-/// stride. Also return "true" for types with no strides.
-bool isLastMemrefDimUnitStride(MemRefType type);
-
-/// Return "true" if the last N dimensions of the given type are contiguous.
-///
-/// Examples:
-///   - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
-///   considering both _all_ and _only_ the trailing 3 dims,
-///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
-///   considering the trailing 3 dims.
-///
-bool trailingNDimsContiguous(MemRefType type, int64_t n);
-
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4f09d2e41e7ceb..e5a2ae81da0c9a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -808,10 +808,52 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
 
+    /// Return "true" if the last N dimensions are contiguous.
+    ///
+    /// Examples:
+    ///   - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>> is contiguous when
+    ///   considering both _all_ and _only_ the trailing 3 dims,
+    ///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is _only_ contiguous when
+    ///   considering the trailing 3 dims.
+    ///
+    bool areTrailingDimsContiguous(int64_t n);
+
+    /// Return a version of this type with identity layout if it can be
+    /// determined statically that the layout is the canonical contiguous
+    /// strided layout. Otherwise pass the layout into `simplifyAffineMap`
+    /// and return a copy of this type with simplified layout.
+    MemRefType canonicalizeStridedLayout();
+
     /// [deprecated] Returns the memory space in old raw integer representation.
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
 
+    /// Returns the strides of the MemRef if the layout map is in strided form.
+    /// MemRefs with a layout map in strided form include:
+    ///   1. empty or identity layout map, in which case the stride information
+    ///      is the canonical form computed from sizes;
+    ///   2. a StridedLayoutAttr layout;
+    ///   3. any other layout that can be converted into a single affine map
+    ///      layout of the form `K + k0 * d0 + ... kn * dn`, where K and ki's
+    ///      are constants or symbols.
+    ///
+    /// A stride specification is a list of integer values that are either
+    /// static or dynamic (encoded with ShapedType::kDynamic). Strides encode
+    /// the distance in the number of elements between successive entries along
+    /// a particular dimension.
+    LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
+                                      int64_t &offset);
+
+    /// Wrapper around getStridesAndOffset(SmallVectorImpl<int64_t>, int64_t)
+    /// that will assert if the stride computation does not succeed.
+    std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset();
+
+    /// Return "true" if the layout is compatible with strided semantics.
+    bool isStrided();
+
+    /// Return "true" if the last dimension has a static unit stride. Also
+    /// return "true" for types with no strides.
+    bool isLastDimUnitStride();
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 6f52195c1d7c92..64270469cad78d 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -820,7 +820,7 @@ class StaticShapeMemRefOf<list<Type> allowedTypes> :
 def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
 
 // For a MemRefType, verify that it has strides.
-def HasStridesPred : CPred<[{ isStrided(::llvm::cast<::mlir::MemRefType>($_self)) }]>;
+def HasStridesPred : CPred<[{ ::llvm::cast<::mlir::MemRefType>($_self).isStrided() }]>;
 
 class StridedMemRefOf<list<Type> allowedTypes> :
     ConfinedType<MemRefOf<allowedTypes>, [HasStridesPred],
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 250e4a6bbf8dfd..26feaf79149b23 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -514,7 +514,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
                                                     int64_t *offset) {
   MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
   SmallVector<int64_t> strides_;
-  if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
+  if (failed(memrefType.getStridesAndOffset(strides_, *offset)))
     return mlirLogicalResultFailure();
 
   (void)std::copy(strides_.begin(), strides_.end(), strides);
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 1564e417a7a48e..9c69da9eb0c6e9 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -192,7 +192,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     // Construct buffer descriptor from memref, attributes
     int64_t offset = 0;
     SmallVector<int64_t, 5> strides;
-    if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    if (failed(memrefType.getStridesAndOffset(strides, offset)))
       return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
 
     MemRefDescriptor memrefDescriptor(memref);
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index 19c3ba1f950202..63f99eb744a83b 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -52,7 +52,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
   assert(type.hasStaticShape() && "unexpected dynamic shape");
 
   // Extract all strides and offsets and verify they are static.
-  auto [strides, offset] = getStridesAndOffset(type);
+  auto [strides, offset] = type.getStridesAndOffset();
   assert(!ShapedType::isDynamic(offset) && "expected static offset");
   assert(!llvm::any_of(strides, ShapedType::isDynamic) &&
          "expected static strides");
@@ -193,7 +193,7 @@ Value MemRefDescriptor::bufferPtr(OpBuilder &builder, Location loc,
                                   MemRefType type) {
   // When we convert to LLVM, the input memref must have been normalized
   // beforehand. Hence, this call is guaranteed to work.
-  auto [strides, offsetCst] = getStridesAndOffset(type);
+  auto [strides, offsetCst] = type.getStridesAndOffset();
 
   Value ptr = alignedPtr(builder, loc);
   // For zero offsets, we already have the base pointer.
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index d551506485a454..a47a2872ceb073 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -62,7 +62,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
     Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
     ConversionPatternRewriter &rewriter) const {
 
-  auto [strides, offset] = getStridesAndOffset(type);
+  auto [strides, offset] = type.getStridesAndOffset();
 
   MemRefDescriptor memRefDescriptor(memRefDesc);
   // Use a canonical representation of the start address so that later
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 64bdb248dff430..0df6502ff4c1fc 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -486,7 +486,7 @@ LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) const {
 SmallVector<Type, 5>
 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
                                              bool unpackAggregates) const {
-  if (!isStrided(type)) {
+  if (!type.isStrided()) {
     emitError(
         UnknownLoc::get(type.getContext()),
         "conversion to strided form failed either due to non-strided layout "
@@ -604,7 +604,7 @@ bool LLVMTypeConverter::canConvertToBarePtr(BaseMemRefType type) {
 
   int64_t offset = 0;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
+  if (failed(memrefTy.getStridesAndOffset(strides, offset)))
     return false;
 
   for (int64_t stride : strides)
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 86f687d7f2636e..f7542b8b3bc5c7 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1136,7 +1136,7 @@ struct MemRefReshapeOpLowering
       // Extract the offset and strides from the type.
       int64_t offset;
       SmallVector<int64_t> strides;
-      if (failed(getStridesAndOffset(targetMemRefType, strides, offset)))
+      if (failed(targetMemRefType.getStridesAndOffset(strides, offset)))
         return rewriter.notifyMatchFailure(
             reshapeOp, "failed to get stride and offset exprs");
 
@@ -1451,7 +1451,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
 
     int64_t offset;
     SmallVector<int64_t, 4> strides;
-    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+    auto successStrides = viewMemRefType.getStridesAndOffset(strides, offset);
     if (failed(successStrides))
       return viewOp.emitWarning("cannot cast to non-strided shape"), failure();
     assert(offset == 0 && "expected offset to be 0");
@@ -1560,7 +1560,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<memref::AtomicRMWOp> {
     auto memRefType = atomicOp.getMemRefType();
     SmallVector<int64_t> strides;
     int64_t offset;
-    if (failed(getStridesAndOffset(memRefType, strides, offset)))
+    if (failed(memRefType.getStridesAndOffset(strides, offset)))
       return failure();
     auto dataPtr =
         getStridedElementPtr(atomicOp.getLoc(), memRefType, adaptor.getMemref(),
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 5b4414d67fdac0..eaefe9e3857933 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -132,7 +132,7 @@ static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
     return 0;
   int64_t offset = 0;
   SmallVector<int64_t, 2> strides;
-  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
+  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
       strides.back() != 1)
     return std::nullopt;
   int64_t stride = strides[strides.size() - 2];
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d688d8e2ab6588..a1e21cb524bd9a 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -91,7 +91,7 @@ LogicalResult getMemRefAlignment(const LLVMTypeConverter &typeConverter,
 // Check if the last stride is non-unit and has a valid memory space.
 static LogicalResult isMemRefTypeSupported(MemRefType memRefType,
                                            const LLVMTypeConverter &converter) {
-  if (!isLastMemrefDimUnitStride(memRefType))
+  if (!memRefType.isLastDimUnitStride())
     return failure();
   if (failed(converter.getMemRefAddressSpace(memRefType)))
     return failure();
@@ -1374,7 +1374,7 @@ static std::optional<SmallVector<int64_t, 4>>
 computeContiguousStrides(MemRefType memRefType) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(memRefType, strides, offset)))
+  if (failed(memRefType.getStridesAndOffset(strides, offset)))
     return std::nullopt;
   if (!strides.empty() && strides.back() != 1)
     return std::nullopt;
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 01bc65c841e94c..22bf27d229ce5d 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1650,7 +1650,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
       return failure();
     if (xferOp.getVectorType().getRank() != 1)
       return failure();
-    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
+    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
       return failure(); // Handled by ConvertVectorToLLVM
 
     // Loop bounds, step, state...
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8041bdf7da19b3..d3229d2e912966 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -76,8 +76,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   // Validate further transfer op semantics.
   SmallVector<int64_t> strides;
   int64_t offset;
-  if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
-      strides.back() != 1)
+  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
     return rewriter.notifyMatchFailure(
         xferOp, "Buffer must be contiguous in the innermost dimension");
 
@@ -105,7 +104,7 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                    Operation::operand_range offsets) {
   MemRefType srcTy = src.getType();
-  auto [strides, offset] = getStridesAndOffset(srcTy);
+  auto [strides, offset] = srcTy.getStridesAndOffset();
 
   xegpu::CreateNdDescOp ndDesc;
   if (srcTy.hasStaticShape()) {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 492e4781f57810..dfd5d7e212f2f7 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -129,7 +129,7 @@ static bool staticallyOutOfBounds(OpType op) {
     return false;
   int64_t offset;
   SmallVector<int64_t> strides;
-  if (failed(getStridesAndOffset(bufferType, strides, offset)))
+  if (failed(bufferType.getStridesAndOffset(strides, offset)))
     return false;
   int64_t result = offset + op.getIndexOffset().value_or(0);
   if (op.getSgprOffset()) {
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 4eac371d4c1ae4..4cb777b03b1963 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -53,8 +53,7 @@ FailureOr<Value> getStride(ConversionPatternRewriter &rewriter,
   unsigned bytes = width >> 3;
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(mType, strides, offset)) ||
-      strides.back() != 1)
+  if (failed(mType.getStridesAndOffset(strides, offset)) || strides.back() != 1)
     return failure();
   if (strides[preLast] == ShapedType::kDynamic) {
     // Dynamic stride needs code to compute the stride at runtime.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index f1841b860ff81a..6be55a1d282240 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -42,8 +42,8 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
   auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
     int64_t sourceOffset, targetOffset;
     SmallVector<int64_t, 4> sourceStrides, targetStrides;
-    if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
-        failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+    if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
+        failed(target.getStridesAndOffset(targetStrides, targetOffset)))
       return false;
     auto dynamicToStatic = [](int64_t a, int64_t b) {
       return ShapedType::isDynamic(a) && !ShapedType::isDynamic(b);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
index 2502744cb3f580..ce0f112dc2dd22 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -29,7 +29,7 @@ using MemCpyFn = bufferization::BufferResultsToOutParamsOpts::MemCpyFn;
 static bool hasFullyDynamicLayoutMap(MemRefType type) {
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(type, strides, offset)))
+  if (failed(type.getStridesAndOffset(strides, offset)))
     return false;
   if (!llvm::all_of(strides, ShapedType::isDynamic))
     return false;
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 49209229259a73..301066e7d3e1f8 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1903,7 +1903,7 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
   auto operand = resMatrixType.getOperand();
   auto srcMemrefType = llvm::cast<MemRefType>(srcType);
 
-  if (!isLastMemrefDimUnitStride(srcMemrefType))
+  if (!srcMemrefType.isLastDimUnitStride())
     return emitError(
         "expected source memref most minor dim must have unit stride");
 
@@ -1923,7 +1923,7 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
   auto srcMatrixType = llvm::cast<gpu::MMAMatrixType>(srcType);
   auto dstMemrefType = llvm::cast<MemRefType>(dstType);
 
-  if (!isLastMemrefDimUnitStride(dstMemrefType))
+  if (!dstMemrefType.isLastDimUnitStride())
     return emitError(
         "expected destination memref most minor dim must have unit stride");
 
diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
index a504101fb3f2fc..2afdeff3a7be11 100644
--- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp
@@ -67,7 +67,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
   }
 
-  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
 
   auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
     return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 9aae46a5c288dc..4f75b7618d6367 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -163,7 +163,7 @@ static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
   SmallVector<int64_t> strides;
   int64_t offset;
   LogicalResult hasStaticInformation =
-      getStridesAndOffset(memrefType, strides, offset);
+      memrefType.getStridesAndOffset(strides, offset);
   if (failed(hasStaticInformation))
     return SmallVector<int64_t>();
   return SmallVector<int64_t>(1, offset);
@@ -176,7 +176,7 @@ static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
   SmallVector<int64_t> strides;
   int64_t offset;
   LogicalResult hasStaticInformation =
-      getStridesAndOffset(memrefType, strides, offset);
+      memrefType.getStridesAndOffset(strides, offset);
   if (failed(hasStaticInformation))
     return SmallVector<int64_t>();
   return strides;
@@ -663,8 +663,8 @@ bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
   // Only fold casts between strided memref forms.
   int64_t sourceOffset, resultOffset;
   SmallVector<int64_t, 4> sourceStrides, resultStrides;
-  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
-      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
+      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
     return false;
 
   // If cast is towards more static sizes along any dimension, don't fold.
@@ -708,8 +708,8 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
     if (aT.getLayout() != bT.getLayout()) {
       int64_t aOffset, bOffset;
       SmallVector<int64_t, 4> aStrides, bStrides;
-      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
-          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
+          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
           aStrides.size() != bStrides.size())
         return false;
 
@@ -954,9 +954,9 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
   SmallVector<int64_t> originalStrides, candidateStrides;
   int64_t originalOffset, candidateOffset;
   if (failed(
-          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
+          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
       failed(
-          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
+          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
     return failure();
 
   // For memrefs, a dimension is truly dropped if its corresponding stride is
@@ -1903,7 +1903,7 @@ LogicalResult ReinterpretCastOp::verify() {
   // identity layout.
   int64_t resultOffset;
   SmallVector<int64_t, 4> resultStrides;
-  if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
     return emitError("expected result type to have strided layout but found ")
            << resultType;
 
@@ -2223,7 +2223,7 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                          ArrayRef<ReassociationIndices> reassociation) {
   int64_t srcOffset;
   SmallVector<int64_t> srcStrides;
-  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
     return failure();
   assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
 
@@ -2420,7 +2420,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
   int64_t srcOffset;
   SmallVector<int64_t> srcStrides;
   auto srcShape = srcType.getShape();
-  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
     return failure();
 
   // The result stride of a reassociation group is the stride of the last entry
@@ -2706,7 +2706,7 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   assert(staticStrides.size() == rank && "staticStrides length mismatch");
 
   // Extract source offset and strides.
-  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceMemRefType);
+  auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
 
   // Compute target offset whose value is:
   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
@@ -2912,8 +2912,8 @@ Value SubViewOp::getViewSource() { return getSource(); }
 static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
   int64_t t1Offset, t2Offset;
   SmallVector<int64_t> t1Strides, t2Strides;
-  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
-  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
+  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
+  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
   return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
 }
 
@@ -2928,8 +2928,8 @@ static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
          "incorrect number of dropped dims");
   int64_t t1Offset, t2Offset;
   SmallVector<int64_t> t1Strides, t2Strides;
-  auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
-  auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
+  auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
+  auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
   if (failed(res1) || failed(res2))
     return false;
   for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
@@ -2980,7 +2980,7 @@ LogicalResult SubViewOp::verify() {
            << baseType << " and subview memref type " << subViewType;
 
   // Verify that the base memref type has a strided layout map.
-  if (!isStrided(baseType))
+  if (!baseType.isStrided())
     return emitError("base type ") << baseType << " is not strided";
 
   // Compute the expected result type, assuming that there are no rank
@@ -3261,7 +3261,7 @@ struct SubViewReturnTypeCanonicalizer {
       return nonReducedType;
 
     // Take the strides and offset from the non-rank reduced type.
-    auto [nonReducedStrides, offset] = getStridesAndOffset(nonReducedType);
+    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
 
     // Drop dims from shape and strides.
     SmallVector<int64_t> targetShape;
@@ -3341,7 +3341,7 @@ void TransposeOp::getAsmResultNames(
 static MemRefType inferTransposeResultType(MemRefType memRefType,
                                            AffineMap permutationMap) {
   auto originalSizes = memRefType.getShape();
-  auto [originalStrides, offset] = getStridesAndOffset(memRefType);
+  auto [originalStrides, offset] = memRefType.getStridesAndOffset();
   assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
 
   // Compute permuted sizes and strides.
@@ -3400,10 +3400,10 @@ LogicalResult TransposeOp::verify() {
 
   auto srcType = llvm::cast<MemRefType>(getIn().getType());
   auto resultType = llvm::cast<MemRefType>(getType());
-  auto canonicalResultType = canonicalizeStridedLayout(
-      inferTransposeResultType(srcType, getPermutation()));
+  auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
+                                 .canonicalizeStridedLayout();
 
-  if (canonicalizeStridedLayout(resultType) != canonicalResultType)
+  if (resultType.canonicalizeStridedLayout() != canonicalResultType)
     return emitOpError("result type ")
            << resultType
            << " is not equivalent to the canonical transposed input type "
@@ -3483,7 +3483,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
     // Get offset from old memref view type 'memRefType'.
     int64_t oldOffset;
     SmallVector<int64_t, 4> oldStrides;
-    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
+    if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
       return failure();
     assert(oldOffset == 0 && "Expected 0 offset");
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 28f9061d9873b7..f58385a7777dbc 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -632,7 +632,7 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
         // Currently only handle innermost stride being 1, checking
         SmallVector<int64_t> strides;
         int64_t offset;
-        if (failed(getStridesAndOffset(ty, strides, offset)))
+        if (failed(ty.getStridesAndOffset(strides, offset)))
           return nullptr;
         if (!strides.empty() && strides.back() != 1)
           return nullptr;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index aa008f8407b5d3..b69cbabe0dde97 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -68,9 +68,9 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
   auto newExtractStridedMetadata =
       rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
 
-  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+  auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
 #ifndef NDEBUG
-  auto [resultStrides, resultOffset] = getStridesAndOffset(subview.getType());
+  auto [resultStrides, resultOffset] = subview.getType().getStridesAndOffset();
 #endif // NDEBUG
 
   // Compute the new strides and offset from the base strides and offset:
@@ -363,7 +363,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
   // Collect the statically known information about the original stride.
   Value source = expandShape.getSrc();
   auto sourceType = cast<MemRefType>(source.getType());
-  auto [strides, offset] = getStridesAndOffset(sourceType);
+  auto [strides, offset] = sourceType.getStridesAndOffset();
 
   OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
                                 ? origStrides[groupId]
@@ -503,7 +503,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
   Value source = collapseShape.getSrc();
   auto sourceType = cast<MemRefType>(source.getType());
 
-  auto [strides, offset] = getStridesAndOffset(sourceType);
+  auto [strides, offset] = sourceType.getStridesAndOffset();
 
   SmallVector<OpFoldResult> groupStrides;
   ArrayRef<int64_t> srcShape = sourceType.getShape();
@@ -528,7 +528,7 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
     // but we still have to make the type system happy.
     MemRefType collapsedType = collapseShape.getResultType();
     auto [collapsedStrides, collapsedOffset] =
-        getStridesAndOffset(collapsedType);
+        collapsedType.getStridesAndOffset();
     int64_t finalStride = collapsedStrides[groupId];
     if (ShapedType::isDynamic(finalStride)) {
       // Look for a dynamic stride. At this point we don't know which one is
@@ -581,7 +581,7 @@ static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
       rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
 
   // Collect statically known information.
-  auto [strides, offset] = getStridesAndOffset(sourceType);
+  auto [strides, offset] = sourceType.getStridesAndOffset();
   MemRefType reshapeType = reshape.getResultType();
   unsigned reshapeRank = reshapeType.getRank();
 
@@ -1068,7 +1068,7 @@ class ExtractStridedMetadataOpCastFolder
                  : ofr;
     };
 
-    auto [sourceStrides, sourceOffset] = getStridesAndOffset(memrefType);
+    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
     assert(sourceStrides.size() == rank && "unexpected number of strides");
 
     // Register the new offset.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 450bfa0cec0c7f..f93ae0a7a298f0 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -91,7 +91,7 @@ struct CastOpInterface
     // Get result offset and strides.
     int64_t resultOffset;
     SmallVector<int64_t> resultStrides;
-    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+    if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
       return;
 
     // Check offset.
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 6de744a7f75244..270b43100a3a74 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -27,7 +27,7 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
 
   SmallVector<int64_t> strides;
   int64_t offset;
-  if (failed(getStridesAndOffset(type, strides, offset)))
+  if (failed(type.getStridesAndOffset(strides, offset)))
     return false;
 
   // MemRef is contiguous if outer dimensions are size-1 and inner
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index a027350e8a5f70..ea662941d23da3 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -70,9 +70,9 @@ LogicalResult DeviceAsyncCopyOp::verify() {
   auto srcMemref = llvm::cast<MemRefType>(getSrc().getType());
   auto dstMemref = llvm::cast<MemRefType>(getDst().getType());
 
-  if (!isLastMemrefDimUnitStride(srcMemref))
+  if (!srcMemref.isLastDimUnitStride())
     return emitError("source memref most minor dim must have unit stride");
-  if (!isLastMemrefDimUnitStride(dstMemref))
+  if (!dstMemref.isLastDimUnitStride())
     return emitError("destination memref most minor dim must have unit stride");
   if (!NVGPUDialect::hasSharedMemoryAddressSpace(dstMemref))
     return emitError()
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
index f8c699c65fe49e..10bc1993ffd960 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp
@@ -24,8 +24,8 @@ template <typename OpTy>
 static bool isContiguousXferOp(OpTy op) {
   return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
          op.hasPureBufferSemantics() &&
-         isLastMemrefDimUnitStride(
-             cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
+         cast<MemRefType>(nvgpu::getMemrefOperand(op).getType())
+             .isLastDimUnitStride();
 }
 
 /// Return "true" if the given op is a contiguous and suitable
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index c500815857ca5b..39cca7d363e0df 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -296,7 +296,7 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
   // Check that the last dimension of the read is contiguous. Note that it is
   // possible to expand support for this by scalarizing all the loads during
   // conversion.
-  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
+  auto [strides, offset] = sourceType.getStridesAndOffset();
   return strides.back() == 1;
 }
 
@@ -320,6 +320,6 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
   // Check that the last dimension of the target memref is contiguous. Note that
   // it is possible to expand support for this by scalarizing all the stores
   // during conversion.
-  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
+  auto [strides, offset] = sourceType.getStridesAndOffset();
   return strides.back() == 1;
 }
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 29f7e8afe0773b..c56dbcca2175d4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -206,7 +206,7 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
     int64_t offset;
     SmallVector<int64_t, 4> strides;
     if (!memRefType.hasStaticShape() ||
-        failed(getStridesAndOffset(memRefType, strides, offset)))
+        failed(memRefType.getStridesAndOffset(strides, offset)))
       return std::nullopt;
 
     // To get the size of the memref object in memory, the total size is the
@@ -1225,7 +1225,7 @@ Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
 
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
+  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
       llvm::is_contained(strides, ShapedType::kDynamic) ||
       ShapedType::isDynamic(offset)) {
     return nullptr;
@@ -1256,7 +1256,7 @@ Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
 
   int64_t offset;
   SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
+  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
       llvm::is_contained(strides, ShapedType::kDynamic) ||
       ShapedType::isDynamic(offset)) {
     return nullptr;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 1abcacd6d6db3d..ed3ba321b37ab9 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -186,7 +186,7 @@ struct CollapseShapeOpInterface
         // the source type.
         SmallVector<int64_t> strides;
         int64_t offset;
-        if (failed(getStridesAndOffset(bufferType, strides, offset)))
+        if (failed(bufferType.getStridesAndOffset(strides, offset)))
           return failure();
         resultType = MemRefType::get(
             {}, tensorResultType.getElementType(),
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae1cf95732336a..76bbe479855b62 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4974,7 +4974,7 @@ static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
       (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
     return success();
 
-  if (!isLastMemrefDimUnitStride(memRefTy))
+  if (!memRefTy.isLastDimUnitStride())
     return op->emitOpError("most minor memref dim must have unit stride");
   return success();
 }
@@ -5789,7 +5789,7 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult TypeCastOp::verify() {
-  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
+  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
   if (!canonicalType.getLayout().isIdentity())
     return emitOpError("expects operand to be a memref with identity layout");
   if (!getResultMemRefType().getLayout().isIdentity())
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..314dc44134e049 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -435,7 +435,7 @@ struct TransferReadToVectorLoadLowering
       return rewriter.notifyMatchFailure(read, "not a memref source");
 
     // Non-unit strides are handled by VectorToSCF.
-    if (!isLastMemrefDimUnitStride(memRefType))
+    if (!memRefType.isLastDimUnitStride())
       return rewriter.notifyMatchFailure(read, "!= 1 stride needs VectorToSCF");
 
     // If there is broadcasting involved then we first load the unbroadcasted
@@ -588,7 +588,7 @@ struct TransferWriteToVectorStoreLowering
       });
 
     // Non-unit strides are handled by VectorToSCF.
-    if (!isLastMemrefDimUnitStride(memRefType))
+    if (!memRefType.isLastDimUnitStride())
       return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) {
         diag << "most minor stride is not 1: " << write;
       });
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index b0892d16969d29..5871d6dd5b3e6d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -267,7 +267,7 @@ static MemRefType dropUnitDims(MemRefType inputType,
   auto targetShape = getReducedShape(sizes);
   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
       targetShape, inputType, offsets, sizes, strides);
-  return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
+  return cast<MemRefType>(rankReducedType).canonicalizeStridedLayout();
 }
 
 /// Creates a rank-reducing memref.subview op that drops unit dims from its
@@ -283,8 +283,8 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                     rewriter.getIndexAttr(1));
   MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
 
-  if (canonicalizeStridedLayout(resultType) ==
-      canonicalizeStridedLayout(inputType))
+  if (resultType.canonicalizeStridedLayout() ==
+      inputType.canonicalizeStridedLayout())
     return input;
   return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
                                             sizes, strides);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
index ee622e886f6185..66c23dd6e74950 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -145,8 +145,8 @@ static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
     return MemRefType();
   int64_t aOffset, bOffset;
   SmallVector<int64_t, 4> aStrides, bStrides;
-  if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
-      failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
+      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
       aStrides.size() != bStrides.size())
     return MemRefType();
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 21ec718efd6a7a..84c1deaebcd009 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1243,7 +1243,7 @@ static FailureOr<size_t>
 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
   SmallVector<int64_t> srcStrides;
   int64_t srcOffset;
-  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
     return failure();
 
   auto isUnitDim = [](VectorType type, int dim) {
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index e590d8c43c44b0..7b56cd0cf0e912 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -261,7 +261,7 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
   auto vecRank = vectorType.getRank();
 
-  if (!trailingNDimsContiguous(memrefType, vecRank))
+  if (!memrefType.areTrailingDimsContiguous(vecRank))
     return false;
 
   // Extract the trailing dims and strides of the input memref
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index bd1163bddf7ee0..3924d082f06280 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -645,24 +645,74 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// UnrankedMemRefType
-//===----------------------------------------------------------------------===//
+bool MemRefType::areTrailingDimsContiguous(int64_t n) {
+  if (!isLastDimUnitStride())
+    return false;
 
-unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
-  return detail::getMemorySpaceAsInt(getMemorySpace());
+  auto memrefShape = getShape().take_back(n);
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+
+  if (getLayout().isIdentity())
+    return true;
+
+  int64_t offset;
+  SmallVector<int64_t> stridesFull;
+  if (!succeeded(getStridesAndOffset(stridesFull, offset)))
+    return false;
+  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
+
+  if (strides.empty())
+    return true;
+
+  // Check whether strides match "flattened" dims.
+  SmallVector<int64_t> flattenedDims;
+  auto dimProduct = 1;
+  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
+    dimProduct *= dim;
+    flattenedDims.push_back(dimProduct);
+  }
+
+  strides = strides.drop_back(1);
+  return llvm::equal(strides, llvm::reverse(flattenedDims));
 }
 
-LogicalResult
-UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
-                           Type elementType, Attribute memorySpace) {
-  if (!BaseMemRefType::isValidElementType(elementType))
-    return emitError() << "invalid memref element type";
+MemRefType MemRefType::canonicalizeStridedLayout() {
+  AffineMap m = getLayout().getAffineMap();
 
-  if (!isSupportedMemorySpace(memorySpace))
-    return emitError() << "unsupported memory space Attribute";
+  // Already in canonical form.
+  if (m.isIdentity())
+    return *this;
 
-  return success();
+  // Can't reduce to canonical identity form, return in canonical form.
+  if (m.getNumResults() > 1)
+    return *this;
+
+  // Corner-case for 0-D affine maps.
+  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
+    if (auto cst = llvm::dyn_cast<AffineConstantExpr>(m.getResult(0)))
+      if (cst.getValue() == 0)
+        return MemRefType::Builder(*this).setLayout({});
+    return *this;
+  }
+
+  // 0-D corner case for an empty shape that still has an affine map. Example:
+  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1-element memref
+  // whose offset needs to remain, so return the type unchanged.
+  if (getShape().empty())
+    return *this;
+
+  // If the canonical strided layout for this type's sizes is equal to its
+  // simplified layout, we can just return an empty layout. Otherwise,
+  // just simplify the existing layout.
+  AffineExpr expr = makeCanonicalStridedLayoutExpr(getShape(), getContext());
+  auto simplifiedLayoutExpr =
+      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+  if (expr != simplifiedLayoutExpr)
+    return MemRefType::Builder(*this).setLayout(
+        AffineMapAttr::get(AffineMap::get(m.getNumDims(), m.getNumSymbols(),
+                                          simplifiedLayoutExpr)));
+  return MemRefType::Builder(*this).setLayout({});
 }
 
 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
@@ -783,11 +833,10 @@ static LogicalResult getStridesAndOffset(MemRefType t,
   return success();
 }
 
-LogicalResult mlir::getStridesAndOffset(MemRefType t,
-                                        SmallVectorImpl<int64_t> &strides,
-                                        int64_t &offset) {
+LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
+                                              int64_t &offset) {
   // Happy path: the type uses the strided layout directly.
-  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(t.getLayout())) {
+  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
     llvm::append_range(strides, strided.getStrides());
     offset = strided.getOffset();
     return success();
@@ -797,14 +846,14 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
   // convertible to affine maps.
   AffineExpr offsetExpr;
   SmallVector<AffineExpr, 4> strideExprs;
-  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
+  if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr)))
     return failure();
-  if (auto cst = dyn_cast<AffineConstantExpr>(offsetExpr))
+  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
     offset = cst.getValue();
   else
     offset = ShapedType::kDynamic;
   for (auto e : strideExprs) {
-    if (auto c = dyn_cast<AffineConstantExpr>(e))
+    if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
       strides.push_back(c.getValue());
     else
       strides.push_back(ShapedType::kDynamic);
@@ -812,16 +861,49 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
   return success();
 }
 
-std::pair<SmallVector<int64_t>, int64_t>
-mlir::getStridesAndOffset(MemRefType t) {
+std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {
   SmallVector<int64_t> strides;
   int64_t offset;
-  LogicalResult status = getStridesAndOffset(t, strides, offset);
+  LogicalResult status = getStridesAndOffset(strides, offset);
   (void)status;
   assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset");
   return {strides, offset};
 }
 
+bool MemRefType::isStrided() {
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto res = getStridesAndOffset(strides, offset);
+  return succeeded(res);
+}
+
+bool MemRefType::isLastDimUnitStride() {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  auto successStrides = getStridesAndOffset(strides, offset);
+  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedMemRefType
+//===----------------------------------------------------------------------===//
+
+unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
+  return detail::getMemorySpaceAsInt(getMemorySpace());
+}
+
+LogicalResult
+UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
+                           Type elementType, Attribute memorySpace) {
+  if (!BaseMemRefType::isValidElementType(elementType))
+    return emitError() << "invalid memref element type";
+
+  if (!isSupportedMemorySpace(memorySpace))
+    return emitError() << "unsupported memory space Attribute";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 /// TupleType
 //===----------------------------------------------------------------------===//
@@ -849,49 +931,6 @@ size_t TupleType::size() const { return getImpl()->size(); }
 // Type Utilities
 //===----------------------------------------------------------------------===//
 
-/// Return a version of `t` with identity layout if it can be determined
-/// statically that the layout is the canonical contiguous strided layout.
-/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
-/// `t` with simplified layout.
-/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
-MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
-  AffineMap m = t.getLayout().getAffineMap();
-
-  // Already in canonical form.
-  if (m.isIdentity())
-    return t;
-
-  // Can't reduce to canonical identity form, return in canonical form.
-  if (m.getNumResults() > 1)
-    return t;
-
-  // Corner-case for 0-D affine maps.
-  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
-    if (auto cst = dyn_cast<AffineConstantExpr>(m.getResult(0)))
-      if (cst.getValue() == 0)
-        return MemRefType::Builder(t).setLayout({});
-    return t;
-  }
-
-  // 0-D corner case for empty shape that still have an affine map. Example:
-  // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
-  // offset needs to remain, just return t.
-  if (t.getShape().empty())
-    return t;
-
-  // If the canonical strided layout for the sizes of `t` is equal to the
-  // simplified layout of `t` we can just return an empty layout. Otherwise,
-  // just simplify the existing layout.
-  AffineExpr expr =
-      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
-  auto simplifiedLayoutExpr =
-      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
-  if (expr != simplifiedLayoutExpr)
-    return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
-        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
-  return MemRefType::Builder(t).setLayout({});
-}
-
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                 ArrayRef<AffineExpr> exprs,
                                                 MLIRContext *context) {
@@ -932,49 +971,3 @@ AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
     exprs.push_back(getAffineDimExpr(dim, context));
   return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
 }
-
-bool mlir::isStrided(MemRefType t) {
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto res = getStridesAndOffset(t, strides, offset);
-  return succeeded(res);
-}
-
-bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  auto successStrides = getStridesAndOffset(type, strides, offset);
-  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
-}
-
-bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
-  if (!isLastMemrefDimUnitStride(type))
-    return false;
-
-  auto memrefShape = type.getShape().take_back(n);
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
-
-  if (type.getLayout().isIdentity())
-    return true;
-
-  int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
-
-  if (strides.empty())
-    return true;
-
-  // Check whether strides match "flattened" dims.
-  SmallVector<int64_t> flattenedDims;
-  auto dimProduct = 1;
-  for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
-    dimProduct *= dim;
-    flattenedDims.push_back(dimProduct);
-  }
-
-  strides = strides.drop_back(1);
-  return llvm::equal(strides, llvm::reverse(flattenedDims));
-}
diff --git a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
index 968e10b8d0cab6..f17f5db2fa22fe 100644
--- a/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
+++ b/mlir/test/lib/Analysis/TestMemRefStrideCalculation.cpp
@@ -35,7 +35,7 @@ void TestMemRefStrideCalculation::runOnOperation() {
     auto memrefType = cast<MemRefType>(allocOp.getResult().getType());
     int64_t offset;
     SmallVector<int64_t, 4> strides;
-    if (failed(getStridesAndOffset(memrefType, strides, offset))) {
+    if (failed(memrefType.getStridesAndOffset(strides, offset))) {
       llvm::outs() << "MemRefType " << memrefType << " cannot be converted to "
                    << "strided form\n";
       return;


