[Mlir-commits] [mlir] 640103b - [mlir][MemRef][~NFC] Move getStridesAndOffset() onto layouts (#138011)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 5 09:09:35 PDT 2025
Author: Krzysztof Drewniak
Date: 2025-05-05T09:09:32-07:00
New Revision: 640103b91ac892cfbeeb614495698c321437b567
URL: https://github.com/llvm/llvm-project/commit/640103b91ac892cfbeeb614495698c321437b567
DIFF: https://github.com/llvm/llvm-project/commit/640103b91ac892cfbeeb614495698c321437b567.diff
LOG: [mlir][MemRef][~NFC] Move getStridesAndOffset() onto layouts (#138011)
This commit refactors the getStridesAndOffset() method on MemRefType to
just call `MemRefLayoutAttrInterface::getStridesAndOffset(shape,
strides&, offset&)`, allowing downstream users and future layouts (e.g., a
potential contiguous layout) to implement it without needing to patch
BuiltinTypes and without needing to conform their affine maps to the
canonical strided form.
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/IR/BuiltinAttributeInterfaces.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinTypes.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index b969b60a66f16..b94a933b5c945 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -270,6 +270,12 @@ LogicalResult
verifyAffineMapAsLayout(AffineMap m, ArrayRef<int64_t> shape,
function_ref<InFlightDiagnostic()> emitError);
+// Return the strides and offsets that can be inferred from the given affine
+// layout map given the map and a memref shape.
+LogicalResult getAffineMapStridesAndOffset(AffineMap map,
+ ArrayRef<int64_t> shape,
+ SmallVectorImpl<int64_t> &strides,
+ int64_t &offset);
} // namespace detail
} // namespace mlir
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 6220d80264bdf..cf9697457f4d8 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -509,6 +509,23 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
return ::mlir::detail::verifyAffineMapAsLayout($_attr.getAffineMap(),
shape, emitError);
}]
+ >,
+
+ InterfaceMethod<
+ [{Return the strides (using ShapedType::kDynamic for the dynamic case)
+ that this layout corresponds to into `strides` and `offset` if such exist
+ and can be determined from a combination of the layout and the given
+ `shape`. If these strides cannot be inferred, return failure().
+ The values of `strides` and `offset` are undefined on failure.}],
+ "::llvm::LogicalResult", "getStridesAndOffset",
+ (ins "::llvm::ArrayRef<int64_t>":$shape,
+ "::llvm::SmallVectorImpl<int64_t>&":$strides,
+ "int64_t&":$offset),
+ [{}],
+ [{
+ return ::mlir::detail::getAffineMapStridesAndOffset(
+ $_attr.getAffineMap(), shape, strides, offset);
+ }]
>
];
}
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 0169f4b38bbe0..854a24ab8605c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -1003,7 +1003,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr<
def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
[DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface,
- ["verifyLayout"]>]> {
+ ["verifyLayout", "getStridesAndOffset"]>]> {
let summary = "An Attribute representing a strided layout of a shaped type";
let description = [{
Syntax:
diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
index 9b5235a6c5ceb..9e8ce4ca3a902 100644
--- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp
@@ -83,3 +83,138 @@ LogicalResult mlir::detail::verifyAffineMapAsLayout(
return success();
}
+
+// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
+// i.e. single term). Accumulate the AffineExpr into the existing one.
+static void extractStridesFromTerm(AffineExpr e,
+ AffineExpr multiplicativeFactor,
+ MutableArrayRef<AffineExpr> strides,
+ AffineExpr &offset) {
+ if (auto dim = dyn_cast<AffineDimExpr>(e))
+ strides[dim.getPosition()] =
+ strides[dim.getPosition()] + multiplicativeFactor;
+ else
+ offset = offset + e * multiplicativeFactor;
+}
+
+/// Takes a single AffineExpr `e` and populates the `strides` array with the
+/// strides expressions for each dim position.
+/// The convention is that the strides for dimensions d0, .. dn appear in
+/// order to make indexing intuitive into the result.
+static LogicalResult extractStrides(AffineExpr e,
+ AffineExpr multiplicativeFactor,
+ MutableArrayRef<AffineExpr> strides,
+ AffineExpr &offset) {
+ auto bin = dyn_cast<AffineBinaryOpExpr>(e);
+ if (!bin) {
+ extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
+ return success();
+ }
+
+ if (bin.getKind() == AffineExprKind::CeilDiv ||
+ bin.getKind() == AffineExprKind::FloorDiv ||
+ bin.getKind() == AffineExprKind::Mod)
+ return failure();
+
+ if (bin.getKind() == AffineExprKind::Mul) {
+ auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
+ if (dim) {
+ strides[dim.getPosition()] =
+ strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
+ return success();
+ }
+ // LHS and RHS may both contain complex expressions of dims. Try one path
+ // and if it fails try the other. This is guaranteed to succeed because
+ // only one path may have a `dim`, otherwise this is not an AffineExpr in
+ // the first place.
+ if (bin.getLHS().isSymbolicOrConstant())
+ return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
+ strides, offset);
+ return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
+ strides, offset);
+ }
+
+ if (bin.getKind() == AffineExprKind::Add) {
+ auto res1 =
+ extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
+ auto res2 =
+ extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
+ return success(succeeded(res1) && succeeded(res2));
+ }
+
+ llvm_unreachable("unexpected binary operation");
+}
+
+/// A stride specification is a list of integer values that are either static
+/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
+/// the distance in the number of elements between successive entries along a
+/// particular dimension.
+///
+/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
+/// non-contiguous memory region of `42` by `16` `f32` elements in which the
+/// distance between two consecutive elements along the outer dimension is `64`
+/// and the distance between two consecutive elements along the inner dimension
+/// is `1`.
+///
+/// The convention is that the strides for dimensions d0, .. dn appear in
+/// order to make indexing intuitive into the result.
+static LogicalResult getStridesAndOffset(AffineMap m, ArrayRef<int64_t> shape,
+ SmallVectorImpl<AffineExpr> &strides,
+ AffineExpr &offset) {
+ if (m.getNumResults() != 1 && !m.isIdentity())
+ return failure();
+
+ auto zero = getAffineConstantExpr(0, m.getContext());
+ auto one = getAffineConstantExpr(1, m.getContext());
+ offset = zero;
+ strides.assign(shape.size(), zero);
+
+ // Canonical case for empty map.
+ if (m.isIdentity()) {
+ // 0-D corner case, offset is already 0.
+ if (shape.empty())
+ return success();
+ auto stridedExpr = makeCanonicalStridedLayoutExpr(shape, m.getContext());
+ if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
+ return success();
+ assert(false && "unexpected failure: extract strides in canonical layout");
+ }
+
+ // Non-canonical case requires more work.
+ auto stridedExpr =
+ simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+ if (failed(extractStrides(stridedExpr, one, strides, offset))) {
+ offset = AffineExpr();
+ strides.clear();
+ return failure();
+ }
+
+ // Simplify results to allow folding to constants and simple checks.
+ unsigned numDims = m.getNumDims();
+ unsigned numSymbols = m.getNumSymbols();
+ offset = simplifyAffineExpr(offset, numDims, numSymbols);
+ for (auto &stride : strides)
+ stride = simplifyAffineExpr(stride, numDims, numSymbols);
+
+ return success();
+}
+
+LogicalResult mlir::detail::getAffineMapStridesAndOffset(
+ AffineMap map, ArrayRef<int64_t> shape, SmallVectorImpl<int64_t> &strides,
+ int64_t &offset) {
+ AffineExpr offsetExpr;
+ SmallVector<AffineExpr, 4> strideExprs;
+ if (failed(::getStridesAndOffset(map, shape, strideExprs, offsetExpr)))
+ return failure();
+ if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
+ offset = cst.getValue();
+ else
+ offset = ShapedType::kDynamic;
+ for (auto e : strideExprs) {
+ if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
+ strides.push_back(c.getValue());
+ else
+ strides.push_back(ShapedType::kDynamic);
+ }
+ return success();
+}
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index e9af1f77a379e..617dcc222cd6e 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -258,6 +258,15 @@ LogicalResult StridedLayoutAttr::verifyLayout(
return success();
}
+LogicalResult
+StridedLayoutAttr::getStridesAndOffset(ArrayRef<int64_t>,
+ SmallVectorImpl<int64_t> &strides,
+ int64_t &offset) const {
+ llvm::append_range(strides, getStrides());
+ offset = getOffset();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..d47e360e9dc13 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -715,150 +715,9 @@ MemRefType MemRefType::canonicalizeStridedLayout() {
return MemRefType::Builder(*this).setLayout({});
}
-// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
-// i.e. single term). Accumulate the AffineExpr into the existing one.
-static void extractStridesFromTerm(AffineExpr e,
- AffineExpr multiplicativeFactor,
- MutableArrayRef<AffineExpr> strides,
- AffineExpr &offset) {
- if (auto dim = dyn_cast<AffineDimExpr>(e))
- strides[dim.getPosition()] =
- strides[dim.getPosition()] + multiplicativeFactor;
- else
- offset = offset + e * multiplicativeFactor;
-}
-
-/// Takes a single AffineExpr `e` and populates the `strides` array with the
-/// strides expressions for each dim position.
-/// The convention is that the strides for dimensions d0, .. dn appear in
-/// order to make indexing intuitive into the result.
-static LogicalResult extractStrides(AffineExpr e,
- AffineExpr multiplicativeFactor,
- MutableArrayRef<AffineExpr> strides,
- AffineExpr &offset) {
- auto bin = dyn_cast<AffineBinaryOpExpr>(e);
- if (!bin) {
- extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
- return success();
- }
-
- if (bin.getKind() == AffineExprKind::CeilDiv ||
- bin.getKind() == AffineExprKind::FloorDiv ||
- bin.getKind() == AffineExprKind::Mod)
- return failure();
-
- if (bin.getKind() == AffineExprKind::Mul) {
- auto dim = dyn_cast<AffineDimExpr>(bin.getLHS());
- if (dim) {
- strides[dim.getPosition()] =
- strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
- return success();
- }
- // LHS and RHS may both contain complex expressions of dims. Try one path
- // and if it fails try the other. This is guaranteed to succeed because
- // only one path may have a `dim`, otherwise this is not an AffineExpr in
- // the first place.
- if (bin.getLHS().isSymbolicOrConstant())
- return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
- strides, offset);
- return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
- strides, offset);
- }
-
- if (bin.getKind() == AffineExprKind::Add) {
- auto res1 =
- extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
- auto res2 =
- extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
- return success(succeeded(res1) && succeeded(res2));
- }
-
- llvm_unreachable("unexpected binary operation");
-}
-
-/// A stride specification is a list of integer values that are either static
-/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
-/// the distance in the number of elements between successive entries along a
-/// particular dimension.
-///
-/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
-/// non-contiguous memory region of `42` by `16` `f32` elements in which the
-/// distance between two consecutive elements along the outer dimension is `1`
-/// and the distance between two consecutive elements along the inner dimension
-/// is `64`.
-///
-/// The convention is that the strides for dimensions d0, .. dn appear in
-/// order to make indexing intuitive into the result.
-static LogicalResult getStridesAndOffset(MemRefType t,
- SmallVectorImpl<AffineExpr> &strides,
- AffineExpr &offset) {
- AffineMap m = t.getLayout().getAffineMap();
-
- if (m.getNumResults() != 1 && !m.isIdentity())
- return failure();
-
- auto zero = getAffineConstantExpr(0, t.getContext());
- auto one = getAffineConstantExpr(1, t.getContext());
- offset = zero;
- strides.assign(t.getRank(), zero);
-
- // Canonical case for empty map.
- if (m.isIdentity()) {
- // 0-D corner case, offset is already 0.
- if (t.getRank() == 0)
- return success();
- auto stridedExpr =
- makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
- if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
- return success();
- assert(false && "unexpected failure: extract strides in canonical layout");
- }
-
- // Non-canonical case requires more work.
- auto stridedExpr =
- simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
- if (failed(extractStrides(stridedExpr, one, strides, offset))) {
- offset = AffineExpr();
- strides.clear();
- return failure();
- }
-
- // Simplify results to allow folding to constants and simple checks.
- unsigned numDims = m.getNumDims();
- unsigned numSymbols = m.getNumSymbols();
- offset = simplifyAffineExpr(offset, numDims, numSymbols);
- for (auto &stride : strides)
- stride = simplifyAffineExpr(stride, numDims, numSymbols);
-
- return success();
-}
-
LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
int64_t &offset) {
- // Happy path: the type uses the strided layout directly.
- if (auto strided = llvm::dyn_cast<StridedLayoutAttr>(getLayout())) {
- llvm::append_range(strides, strided.getStrides());
- offset = strided.getOffset();
- return success();
- }
-
- // Otherwise, defer to the affine fallback as layouts are supposed to be
- // convertible to affine maps.
- AffineExpr offsetExpr;
- SmallVector<AffineExpr, 4> strideExprs;
- if (failed(::getStridesAndOffset(*this, strideExprs, offsetExpr)))
- return failure();
- if (auto cst = llvm::dyn_cast<AffineConstantExpr>(offsetExpr))
- offset = cst.getValue();
- else
- offset = ShapedType::kDynamic;
- for (auto e : strideExprs) {
- if (auto c = llvm::dyn_cast<AffineConstantExpr>(e))
- strides.push_back(c.getValue());
- else
- strides.push_back(ShapedType::kDynamic);
- }
- return success();
+ return getLayout().getStridesAndOffset(getShape(), strides, offset);
}
std::pair<SmallVector<int64_t>, int64_t> MemRefType::getStridesAndOffset() {
More information about the Mlir-commits
mailing list