[Mlir-commits] [mlir] [mlir][MemRef][~NFC] Move getStridesAndOffset() onto layouts (PR #138011)

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Apr 30 11:28:36 PDT 2025


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/138011

This commit refactors the getStridesAndOffset() method on MemRefType to just call `MemRefLayoutAttrInterface::getStridesAndOffset(shape, strides&, offset&)`, allowing downstream users and future layouts (e.g., a potential contiguous layout) to implement it without needing to patch BuiltinTypes or to make their affine maps conform to the canonical strided form.

>From 4a0b1cc77f12307ab7717079c9aaa3bab8da78d9 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 30 Apr 2025 18:14:42 +0000
Subject: [PATCH] [mlir][MemRef][~NFC] Move getStridesAndOffset() onto layouts

This commit refactors the getStridesAndOffset() method on MemRefType to
just call `MemRefLayoutAttrInterface::getStridesAndOffset(shape,
strides&, offset&)`, allowing downstream users and future layouts (e.g.,
a potential contiguous layout) to implement it without needing to
patch BuiltinTypes or to make their affine maps conform to the
canonical strided form.
---
 .../mlir/IR/BuiltinAttributeInterfaces.h      |   6 +
 .../mlir/IR/BuiltinAttributeInterfaces.td     |  17 +++
 mlir/include/mlir/IR/BuiltinAttributes.td     |   2 +-
 mlir/lib/IR/BuiltinAttributeInterfaces.cpp    | 135 +++++++++++++++++
 mlir/lib/IR/BuiltinAttributes.cpp             |   9 ++
 mlir/lib/IR/BuiltinTypes.cpp                  | 143 +-----------------
 6 files changed, 169 insertions(+), 143 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index b969b60a66f16..b94a933b5c945 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -270,6 +270,16 @@ LogicalResult
 verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
                         function_ref<InFlightDiagnostic()> emitError);
 
+/// Compute the strides and offset of a memref with the given `shape` whose
+/// layout is the affine map `map`. One stride per dimension is appended to
+/// `strides`; strides or an offset that do not fold to constants are
+/// reported as ShapedType::kDynamic. Returns failure if `map` cannot be put
+/// in strided form (e.g. it uses floordiv, ceildiv or mod).
+LogicalResult getAffineMapStridesAndOffset(AffineMap map,
+                                           ArrayRef<int64_t> shape,
+                                           SmallVectorImpl<int64_t> &strides,
+                                           int64_t &offset);
+
 } // namespace detail
 
 } // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 6220d80264bdf..fd757578eb1f2 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -509,6 +509,24 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
         return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
                                                        shape, emitError);
       }]
+    >,
+
+    InterfaceMethod<
+      [{Compute the strides and offset that this layout corresponds to for a
+      memref of shape `shape`, storing them into `strides` and `offset`.
+      Dynamic strides or a dynamic offset are reported as
+      ShapedType::kDynamic. Return failure() if this layout cannot be
+      expressed as strides and an offset for `shape`; the contents of
+      `strides` and `offset` are unspecified on failure.}],
+      "::llvm::LogicalResult", "getStridesAndOffset",
+      (ins "::llvm::ArrayRef<int64_t>":$shape,
+           "::llvm::SmallVectorImpl<int64_t>&":$strides,
+           "int64_t&":$offset),
+      /*methodBody=*/[{}],
+      /*defaultImplementation=*/[{
+        return ::mlir::detail::getAffineMapStridesAndOffset(
+            $_attr.getAffineMap(), shape, strides, offset);
+      }]
     >
   ];
 }
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 0169f4b38bbe0..854a24ab8605c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -1003,7 +1003,8 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
 
 def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
     [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface,
-                                 ["verifyLayout"]>]> {
+                                 ["verifyLayout",
+                                  "getStridesAndOffset"]>]> {
   let summary = "An Attribute representing a strided layout of a shaped type";
   let description = [{
     Syntax:
diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index 9b5235a6c5ceb..9e8ce4ca3a902 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -83,3 +83,153 @@ LogicalResult mlir::detail::verifyAffineMapAsLayout(
 
   return success();
 }
+
+// Fallback for a terminal dim/sym/cst that is not part of a binary op (i.e.
+// a single term): if `e` is a dim, add `multiplicativeFactor` to that
+// dimension's stride; otherwise add `e * multiplicativeFactor` to `offset`.
+static void extractStridesFromTerm(AffineExpr e,
+                                   AffineExpr multiplicativeFactor,
+                                   MutableArrayRef<AffineExpr> strides,
+                                   AffineExpr &offset) {
+  if (auto dim = dyn_cast<AffineDimExpr>(e))
+    strides[dim.getPosition()] =
+        strides[dim.getPosition()] + multiplicativeFactor;
+  else
+    offset = offset + e * multiplicativeFactor;
+}
+
+/// Takes a single AffineExpr `e` and populates the `strides` array with the
+/// strides expressions for each dim position.
+/// The convention is that the strides for dimensions d0, .. dn appear in
+/// order to make indexing intuitive into the result.
+/// Returns failure if a floordiv, ceildiv or mod is reached during the
+/// traversal, since those have no strided equivalent.
+static LogicalResult extractStrides(AffineExpr e,
+                                    AffineExpr multiplicativeFactor,
+                                    MutableArrayRef<AffineExpr> strides,
+                                    AffineExpr &offset) {
+  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
+  if (!bin) {
+    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
+    return success();
+  }
+
+  if (bin.getKind() == AffineExprKind::CeilDiv ||
+      bin.getKind() == AffineExprKind::FloorDiv ||
+      bin.getKind() == AffineExprKind::Mod)
+    return failure();
+
+  if (bin.getKind() == AffineExprKind::Mul) {
+    auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
+    if (dim) {
+      strides[dim.getPosition()] =
+          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
+      return success();
+    }
+    // LHS and RHS may both be compound expressions. In a valid affine
+    // expression at most one side of a multiplication involves dims, so fold
+    // the symbolic-or-constant side into the multiplicative factor and
+    // recurse into the other side.
+    if (bin.getLHS().isSymbolicOrConstant())
+      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
+                            strides, offset);
+    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
+                          strides, offset);
+  }
+
+  if (bin.getKind() == AffineExprKind::Add) {
+    auto res1 =
+        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
+    auto res2 =
+        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
+    return success(succeeded(res1) && succeeded(res2));
+  }
+
+  llvm_unreachable("unexpected binary operation");
+}
+
+/// A stride specification is a list of integer values that are either static
+/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
+/// the distance in the number of elements between successive entries along a
+/// particular dimension.
+///
+/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
+/// non-contiguous memory region of `42` by `16` `f32` elements in which the
+/// distance between two consecutive elements along the outer dimension is `64`
+/// and the distance between two consecutive elements along the inner dimension
+/// is `1`.
+///
+/// The convention is that the strides for dimensions d0, .. dn appear in
+/// order to make indexing intuitive into the result.
+///
+/// Fails if `m` does not have one dimension per entry of `shape`, if it is a
+/// non-identity map with more than one result, or if its result cannot be
+/// decomposed into strides and an offset.
+static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
+                                         SmallVectorImpl<AffineExpr> &strides,
+                                         AffineExpr &offset) {
+  // `strides` has one entry per dimension of `shape` and is indexed by dim
+  // position below, so reject maps whose dim count differs from the rank.
+  if (m.getNumDims() != shape.size())
+    return failure();
+  if (m.getNumResults() != 1 && !m.isIdentity())
+    return failure();
+
+  auto zero = getAffineConstantExpr(0, m.getContext());
+  auto one = getAffineConstantExpr(1, m.getContext());
+  offset = zero;
+  strides.assign(shape.size(), zero);
+
+  // Canonical case: identity layout.
+  if (m.isIdentity()) {
+    // 0-D corner case, offset is already 0.
+    if (shape.empty())
+      return success();
+    auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
+    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
+      return success();
+    assert(false && "unexpected failure: extract strides in canonical layout");
+  }
+
+  // Non-canonical case requires more work.
+  auto stridedExpr =
+      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
+    offset = AffineExpr();
+    strides.clear();
+    return failure();
+  }
+
+  // Simplify results to allow folding to constants and simple checks.
+  unsigned numDims = m.getNumDims();
+  unsigned numSymbols = m.getNumSymbols();
+  offset = simplifyAffineExpr(offset, numDims, numSymbols);
+  for (auto &stride : strides)
+    stride = simplifyAffineExpr(stride, numDims, numSymbols);
+
+  return success();
+}
+
+/// Converts the AffineExpr strides and offset computed for `map` into integers
+/// and appends the strides to `strides`. Anything that did not simplify to a
+/// constant is reported as ShapedType::kDynamic.
+LogicalResult mlir::detail::getAffineMapStridesAndOffset(
+    AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides,
+    int64_t &offset) {
+  AffineExpr offsetExpr;
+  SmallVector<AffineExpr, 4> strideExprs;
+  if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
+    return failure();
+  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
+    offset = cst.getValue();
+  else
+    offset = ShapedType::kDynamic;
+  strides.reserve(strides.size() + strideExprs.size());
+  for (auto e : strideExprs) {
+    if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
+      strides.push_back(c.getValue());
+    else
+      strides.push_back(ShapedType::kDynamic);
+  }
+  return success();
+}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index e9af1f77a379e..617dcc222cd6e 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -258,6 +258,17 @@ LogicalResult StridedLayoutAttr::verifyLayout(
   return success();
 }
 
+/// A strided layout stores its strides and offset explicitly, so they are
+/// reported as-is and the memref shape is not consulted.
+LogicalResult StridedLayoutAttr::getStridesAndOffset(
+    ArrayRef<int64_t> /*shape*/, SmallVectorImpl<int64_t> &strides,
+    int64_t &offset) const {
+  ArrayRef<int64_t> storedStrides = getStrides();
+  strides.append(storedStrides.begin(), storedStrides.end());
+  offset = getOffset();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..d47e360e9dc13 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -715,150 +715,12 @@ MemRefType MemRefType::canonicalizeStridedLayout() {
   return MemRefType::Builder(*this).setLayout({});
 }
 
-// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
-// i.e. single term). Accumulate the AffineExpr into the existing one.
-static void extractStridesFromTerm(AffineExpr e,
-                                   AffineExpr multiplicativeFactor,
-                                   MutableArrayRef<AffineExpr> strides,
-                                   AffineExpr &offset) {
-  if (auto dim = dyn_cast<AffineDimExpr>(e))
-    strides[dim.getPosition()] =
-        strides[dim.getPosition()] + multiplicativeFactor;
-  else
-    offset = offset + e * multiplicativeFactor;
-}
-
-/// Takes a single AffineExpr `e` and populates the `strides` array with the
-/// strides expressions for each dim position.
-/// The convention is that the strides for dimensions d0, .. dn appear in
-/// order to make indexing intuitive into the result.
-static LogicalResult extractStrides(AffineExpr e,
-                                    AffineExpr multiplicativeFactor,
-                                    MutableArrayRef<AffineExpr> strides,
-                                    AffineExpr &offset) {
-  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
-  if (!bin) {
-    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
-    return success();
-  }
-
-  if (bin.getKind() == AffineExprKind::CeilDiv ||
-      bin.getKind() == AffineExprKind::FloorDiv ||
-      bin.getKind() == AffineExprKind::Mod)
-    return failure();
-
-  if (bin.getKind() == AffineExprKind::Mul) {
-    auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
-    if (dim) {
-      strides[dim.getPosition()] =
-          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
-      return success();
-    }
-    // LHS and RHS may both contain complex expressions of dims. Try one path
-    // and if it fails try the other. This is guaranteed to succeed because
-    // only one path may have a `dim`, otherwise this is not an AffineExpr in
-    // the first place.
-    if (bin.getLHS().isSymbolicOrConstant())
-      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
-                            strides, offset);
-    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
-                          strides, offset);
-  }
-
-  if (bin.getKind() == AffineExprKind::Add) {
-    auto res1 =
-        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
-    auto res2 =
-        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
-    return success(succeeded(res1) && succeeded(res2));
-  }
-
-  llvm_unreachable("unexpected binary operation");
-}
-
-/// A stride specification is a list of integer values that are either static
-/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
-/// the distance in the number of elements between successive entries along a
-/// particular dimension.
-///
-/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
-/// non-contiguous memory region of `42` by `16` `f32` elements in which the
-/// distance between two consecutive elements along the outer dimension is `1`
-/// and the distance between two consecutive elements along the inner dimension
-/// is `64`.
-///
-/// The convention is that the strides for dimensions d0, .. dn appear in
-/// order to make indexing intuitive into the result.
-static LogicalResult getStridesAndOffset(MemRefType t,
-                                         SmallVectorImpl<AffineExpr> &strides,
-                                         AffineExpr &offset) {
-  AffineMap m = t.getLayout().getAffineMap();
-
-  if (m.getNumResults() != 1 && !m.isIdentity())
-    return failure();
-
-  auto zero = getAffineConstantExpr(0, t.getContext());
-  auto one = getAffineConstantExpr(1, t.getContext());
-  offset = zero;
-  strides.assign(t.getRank(), zero);
-
-  // Canonical case for empty map.
-  if (m.isIdentity()) {
-    // 0-D corner case, offset is already 0.
-    if (t.getRank() == 0)
-      return success();
-    auto stridedExpr =
-        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
-    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
-      return success();
-    assert(false && "unexpected failure: extract strides in canonical layout");
-  }
-
-  // Non-canonical case requires more work.
-  auto stridedExpr =
-      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
-  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
-    offset = AffineExpr();
-    strides.clear();
-    return failure();
-  }
-
-  // Simplify results to allow folding to constants and simple checks.
-  unsigned numDims = m.getNumDims();
-  unsigned numSymbols = m.getNumSymbols();
-  offset = simplifyAffineExpr(offset, numDims, numSymbols);
-  for (auto &stride : strides)
-    stride = simplifyAffineExpr(stride, numDims, numSymbols);
-
-  return success();
-}
-
 LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
                                               int64_t &offset) {
-  // Happy path: the type uses the strided layout directly.
-  if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
-    llvm::append_range(strides, strided.getStrides());
-    offset = strided.getOffset();
-    return success();
-  }
-
-  // Otherwise, defer to the affine fallback as layouts are supposed to be
-  // convertible to affine maps.
-  AffineExpr offsetExpr;
-  SmallVector<AffineExpr, 4> strideExprs;
-  if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr)))
-    return failure();
-  if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
-    offset = cst.getValue();
-  else
-    offset = ShapedType::kDynamic;
-  for (auto e : strideExprs) {
-    if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
-      strides.push_back(c.getValue());
-    else
-      strides.push_back(ShapedType::kDynamic);
-  }
-  return success();
+  // The layout attribute decides how (and whether) it maps onto strides and
+  // an offset for this memref's shape.
+  MemRefLayoutAttrInterface layout = getLayout();
+  return layout.getStridesAndOffset(getShape(), strides, offset);
 }
 
 std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {



More information about the Mlir-commits mailing list