[Mlir-commits] [mlir] [mlir] Use identity map to construct the memref type as possible (PR #183051)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 24 04:46:53 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: donald chen (cxy-1993)
<details>
<summary>Changes</summary>
Many optimizations and analyses in MLIR rely on a memref type having an identity layout. Therefore, an identity map should be used whenever possible when constructing memref types. This patch checks whether the given layout has a zero offset and strides that are contiguous for the input shape; if so, it replaces the layout with the identity map before constructing the memref type.
---
Full diff: https://github.com/llvm/llvm-project/pull/183051.diff
1 Files Affected:
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+105-34)
``````````diff
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1e198043c590a..2774f194ee3d5 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -566,13 +566,46 @@ unsigned MemRefType::getMemorySpaceAsInt() const {
return detail::getMemorySpaceAsInt(getMemorySpace());
}
+/// Returns the number of trailing dimensions of a memref with the given
+/// `shape` and `strides` that are laid out contiguously (each stride equals
+/// the product of the sizes of the dimensions to its right). Size-1
+/// dimensions are ignored. Counting stops at the first dynamic dimension,
+/// which is itself counted when its stride matches. Returns 0 when `strides`
+/// does not provide exactly one stride per dimension.
+///
+/// Note: the result is a count in `[0, rank]`, so it must be an integer type;
+/// a `bool` return would truncate it and break every `== rank` comparison.
+static int64_t getNumContiguousTrailingDimsImpl(ArrayRef<int64_t> shape,
+                                                ArrayRef<int64_t> strides) {
+  const int64_t n = shape.size();
+  // Strides that do not describe every dimension (e.g. produced from a
+  // mismatched or unverified layout) cannot prove any contiguity, and would
+  // otherwise be indexed out of bounds below.
+  if (static_cast<int64_t>(strides.size()) != n)
+    return 0;
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+  // for `i` in `[k, n-1]`.
+  // Ignore stride elements if the corresponding dimension is 1, as they are
+  // of no consequence.
+  int64_t dimProduct = 1;
+  for (int64_t i = n - 1; i >= 0; --i) {
+    if (shape[i] == 1)
+      continue;
+    if (strides[i] != dimProduct)
+      return n - i - 1;
+    if (shape[i] == ShapedType::kDynamic)
+      return n - i;
+    dimProduct *= shape[i];
+  }
+  return n;
+}
+
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
- // Use default layout for empty attribute.
- if (!layout)
+  if (!layout) {
+    // Use default layout for empty attribute.
     layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
         shape.size(), elementType.getContext()));
+  } else {
+    // If the layout is a zero-offset strided layout whose strides are
+    // contiguous for `shape`, canonicalize it to the identity layout.
+    // The rank is checked first so a layout whose affine map does not match
+    // `shape` (which verification rejects) is left untouched rather than
+    // being fed to the stride query. The query's result is checked: on
+    // failure `offset`/`strides` carry no meaningful values.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (layout.getAffineMap().getNumDims() == shape.size() &&
+        succeeded(layout.getStridesAndOffset(shape, strides, offset)) &&
+        offset == 0 && strides.size() == shape.size() &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
+          shape.size(), elementType.getContext()));
+  }
// Drop default memory space value and replace it with empty attribute.
memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -585,10 +618,21 @@ MemRefType MemRefType::getChecked(
function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
- // Use default layout for empty attribute.
- if (!layout)
+  if (!layout) {
+    // Use default layout for empty attribute.
     layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
         shape.size(), elementType.getContext()));
+  } else {
+    // If the layout is a zero-offset strided layout whose strides are
+    // contiguous for `shape`, canonicalize it to the identity layout.
+    // This runs before verification, so a layout whose affine map does not
+    // match the rank of `shape` must be left as-is for the verifier to
+    // diagnose; it is never passed to the stride query. The query's result
+    // is checked: on failure `offset`/`strides` carry no meaningful values.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (layout.getAffineMap().getNumDims() == shape.size() &&
+        succeeded(layout.getStridesAndOffset(shape, strides, offset)) &&
+        offset == 0 && strides.size() == shape.size() &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
+          shape.size(), elementType.getContext()));
+  }
// Drop default memory space value and replace it with empty attribute.
memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -600,10 +644,22 @@ MemRefType MemRefType::getChecked(
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
AffineMap map, Attribute memorySpace) {
- // Use default layout for empty map.
- if (!map)
+  if (!map) {
+    // Use default layout for empty map.
     map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                             elementType.getContext());
+  } else {
+    // If `map` is a zero-offset strided layout whose strides are contiguous
+    // for `shape`, canonicalize it to the identity map. A map whose dim count
+    // does not match the rank of `shape` is left untouched. The query's
+    // result is checked: on failure `offset`/`strides` carry no meaningful
+    // values.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (map.getNumDims() == shape.size() &&
+        succeeded(::mlir::detail::getAffineMapStridesAndOffset(
+            map, shape, strides, offset)) &&
+        offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                              elementType.getContext());
+  }
// Wrap AffineMap into Attribute.
auto layout = AffineMapAttr::get(map);
@@ -620,10 +676,22 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
ArrayRef<int64_t> shape, Type elementType, AffineMap map,
Attribute memorySpace) {
- // Use default layout for empty map.
- if (!map)
+  if (!map) {
+    // Use default layout for empty map.
     map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                             elementType.getContext());
+  } else {
+    // If `map` is a zero-offset strided layout whose strides are contiguous
+    // for `shape`, canonicalize it to the identity map. This runs before
+    // verification, so a map whose dim count does not match the rank of
+    // `shape` is left as-is for the verifier to diagnose. The query's result
+    // is checked: on failure `offset`/`strides` carry no meaningful values.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (map.getNumDims() == shape.size() &&
+        succeeded(::mlir::detail::getAffineMapStridesAndOffset(
+            map, shape, strides, offset)) &&
+        offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                              elementType.getContext());
+  }
// Wrap AffineMap into Attribute.
auto layout = AffineMapAttr::get(map);
@@ -638,10 +706,22 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
AffineMap map, unsigned memorySpaceInd) {
- // Use default layout for empty map.
- if (!map)
+  if (!map) {
+    // Use default layout for empty map.
     map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                             elementType.getContext());
+  } else {
+    // If `map` is a zero-offset strided layout whose strides are contiguous
+    // for `shape`, canonicalize it to the identity map. A map whose dim count
+    // does not match the rank of `shape` is left untouched. The query's
+    // result is checked: on failure `offset`/`strides` carry no meaningful
+    // values.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (map.getNumDims() == shape.size() &&
+        succeeded(::mlir::detail::getAffineMapStridesAndOffset(
+            map, shape, strides, offset)) &&
+        offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                              elementType.getContext());
+  }
// Wrap AffineMap into Attribute.
auto layout = AffineMapAttr::get(map);
@@ -659,10 +739,22 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
ArrayRef<int64_t> shape, Type elementType, AffineMap map,
unsigned memorySpaceInd) {
- // Use default layout for empty map.
- if (!map)
+  if (!map) {
+    // Use default layout for empty map.
     map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                             elementType.getContext());
+  } else {
+    // If `map` is a zero-offset strided layout whose strides are contiguous
+    // for `shape`, canonicalize it to the identity map. This runs before
+    // verification, so a map whose dim count does not match the rank of
+    // `shape` is left as-is for the verifier to diagnose. The query's result
+    // is checked: on failure `offset`/`strides` carry no meaningful values.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (map.getNumDims() == shape.size() &&
+        succeeded(::mlir::detail::getAffineMapStridesAndOffset(
+            map, shape, strides, offset)) &&
+        offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                              elementType.getContext());
+  }
// Wrap AffineMap into Attribute.
auto layout = AffineMapAttr::get(map);
@@ -705,7 +797,6 @@ bool MemRefType::areTrailingDimsContiguous(int64_t n) {
int64_t MemRefType::getNumContiguousTrailingDims() {
const int64_t n = getRank();
-
// memrefs with identity layout are entirely contiguous.
if (getLayout().isIdentity())
return n;
@@ -716,27 +807,7 @@ int64_t MemRefType::getNumContiguousTrailingDims() {
SmallVector<int64_t> strides;
if (!succeeded(getStridesAndOffset(strides, offset)))
return 0;
-
- ArrayRef<int64_t> shape = getShape();
-
- // A memref with dimensions `d0, d1, ..., dn-1` and strides
- // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
- // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
- // for `i` in `[k, n-1]`.
- // Ignore stride elements if the corresponding dimension is 1, as they are
- // of no consequence.
- int64_t dimProduct = 1;
- for (int64_t i = n - 1; i >= 0; --i) {
- if (shape[i] == 1)
- continue;
- if (strides[i] != dimProduct)
- return n - i - 1;
- if (shape[i] == ShapedType::kDynamic)
- return n - i;
- dimProduct *= shape[i];
- }
-
- return n;
+ return getNumContiguousTrailingDimsImpl(getShape(), strides);
}
MemRefType MemRefType::canonicalizeStridedLayout() {
``````````
</details>
https://github.com/llvm/llvm-project/pull/183051
More information about the Mlir-commits
mailing list