[Mlir-commits] [mlir] [mlir] Use identity map to construct the memref type as possible (PR #183051)

donald chen llvmlistbot at llvm.org
Tue Feb 24 04:46:18 PST 2026


https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/183051

Many optimizations and analyses in MLIR rely on memref types having an identity layout. Therefore, memref types should use the identity layout whenever it is equivalent to the requested one. This patch checks whether the strides and offset of the given layout describe a contiguous, zero-offset layout for the given shape and, if so, constructs the memref type with the identity layout instead. For example, `memref<4x8xf32, strided<[8, 1]>>` is now constructed as `memref<4x8xf32>`.

>From def8c05fa6e7c12c92de838593cc0e3f6ee9f971 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Tue, 24 Feb 2026 12:28:52 +0000
Subject: [PATCH] [mlir] Use the identity layout to construct memref types
 whenever possible

Many optimizations and analyses in MLIR rely on memref types having an
identity layout. Therefore, memref types should use the identity layout
whenever it is equivalent to the requested one. This patch checks whether
the strides and offset of the given layout describe a contiguous,
zero-offset layout for the given shape and, if so, constructs the memref
type with the identity layout instead. For example,
`memref<4x8xf32, strided<[8, 1]>>` is now constructed as `memref<4x8xf32>`.
---
 mlir/lib/IR/BuiltinTypes.cpp | 139 ++++++++++++++++++++++++++---------
 1 file changed, 105 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 1e198043c590a..2774f194ee3d5 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -566,13 +566,55 @@ unsigned MemRefType::getMemorySpaceAsInt() const {
   return detail::getMemorySpaceAsInt(getMemorySpace());
 }
 
+/// Returns how many trailing dimensions of a memref with the given `shape`
+/// and `strides` are laid out contiguously in row-major order.
+static int64_t getNumContiguousTrailingDimsImpl(ArrayRef<int64_t> shape,
+                                                ArrayRef<int64_t> strides) {
+  assert(shape.size() == strides.size() && "expected one stride per dim");
+  const int64_t n = shape.size();
+  // A memref with dimensions `d0, d1, ..., dn-1` and strides
+  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
+  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
+  // for `i` in `[k, n-1]`.
+  // Ignore stride elements if the corresponding dimension is 1, as they are
+  // of no consequence.
+  int64_t dimProduct = 1;
+  for (int64_t i = n - 1; i >= 0; --i) {
+    if (shape[i] == 1)
+      continue;
+    if (strides[i] != dimProduct)
+      return n - i - 1;
+    // With a dynamic size, the stride expected for the next outer dimension
+    // is unknown, so stop counting here.
+    if (shape[i] == ShapedType::kDynamic)
+      return n - i;
+    dimProduct *= shape[i];
+  }
+  return n;
+}
+
 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                            MemRefLayoutAttrInterface layout,
                            Attribute memorySpace) {
-  // Use default layout for empty attribute.
-  if (!layout)
+  if (!layout) {
+    // Use default layout for empty attribute.
     layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
         shape.size(), elementType.getContext()));
+  } else if (!layout.isIdentity() &&
+             layout.getAffineMap().getNumDims() == shape.size()) {
+    // If the layout can be inferred to be an identity, prefer using the
+    // identity layout. Layouts whose rank does not match the shape are left
+    // untouched so that verification still reports them, and the computed
+    // strides/offset are only inspected when the layout is strided.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (succeeded(layout.getStridesAndOffset(shape, strides, offset)) &&
+        strides.size() == shape.size() && offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
+          shape.size(), elementType.getContext()));
+  }
 
   // Drop default memory space value and replace it with empty attribute.
   memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -585,10 +618,25 @@ MemRefType MemRefType::getChecked(
     function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
     Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
 
-  // Use default layout for empty attribute.
-  if (!layout)
+  if (!layout) {
+    // Use default layout for empty attribute.
     layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
         shape.size(), elementType.getContext()));
+  } else if (!layout.isIdentity() &&
+             layout.getAffineMap().getNumDims() == shape.size()) {
+    // If the layout can be inferred to be an identity, prefer using the
+    // identity layout. Layouts whose rank does not match the shape are left
+    // untouched so that verification still reports them, and the computed
+    // strides/offset are only inspected when the layout is strided.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (succeeded(layout.getStridesAndOffset(shape, strides, offset)) &&
+        strides.size() == shape.size() && offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
+          shape.size(), elementType.getContext()));
+  }
 
   // Drop default memory space value and replace it with empty attribute.
   memorySpace = skipDefaultMemorySpace(memorySpace);
@@ -600,10 +644,25 @@ MemRefType MemRefType::getChecked(
 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                            AffineMap map, Attribute memorySpace) {
 
-  // Use default layout for empty map.
-  if (!map)
+  if (!map) {
+    // Use default layout for empty map.
     map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                             elementType.getContext());
+  } else if (!map.isIdentity() && map.getNumDims() == shape.size()) {
+    // If the layout can be inferred to be an identity, prefer using the
+    // identity layout. Maps whose rank does not match the shape are left
+    // untouched so that verification still reports them, and the computed
+    // strides/offset are only inspected when the map is strided.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (succeeded(::mlir::detail::getAffineMapStridesAndOffset(
+            map, shape, strides, offset)) &&
+        offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                              elementType.getContext());
+  }
 
   // Wrap AffineMap into Attribute.
   auto layout = AffineMapAttr::get(map);
@@ -620,10 +676,25 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                        ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                        Attribute memorySpace) {
 
-  // Use default layout for empty map.
-  if (!map)
+  if (!map) {
+    // Use default layout for empty map.
     map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                             elementType.getContext());
+  } else if (!map.isIdentity() && map.getNumDims() == shape.size()) {
+    // If the layout can be inferred to be an identity, prefer using the
+    // identity layout. Maps whose rank does not match the shape are left
+    // untouched so that verification still reports them, and the computed
+    // strides/offset are only inspected when the map is strided.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (succeeded(::mlir::detail::getAffineMapStridesAndOffset(
+            map, shape, strides, offset)) &&
+        offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                              elementType.getContext());
+  }
 
   // Wrap AffineMap into Attribute.
   auto layout = AffineMapAttr::get(map);
@@ -638,10 +706,25 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                            AffineMap map, unsigned memorySpaceInd) {
 
-  // Use default layout for empty map.
-  if (!map)
+  if (!map) {
+    // Use default layout for empty map.
     map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                             elementType.getContext());
+  } else if (!map.isIdentity() && map.getNumDims() == shape.size()) {
+    // If the layout can be inferred to be an identity, prefer using the
+    // identity layout. Maps whose rank does not match the shape are left
+    // untouched so that verification still reports them, and the computed
+    // strides/offset are only inspected when the map is strided.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (succeeded(::mlir::detail::getAffineMapStridesAndOffset(
+            map, shape, strides, offset)) &&
+        offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                              elementType.getContext());
+  }
 
   // Wrap AffineMap into Attribute.
   auto layout = AffineMapAttr::get(map);
@@ -659,10 +739,25 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                        ArrayRef<int64_t> shape, Type elementType, AffineMap map,
                        unsigned memorySpaceInd) {
 
-  // Use default layout for empty map.
-  if (!map)
+  if (!map) {
+    // Use default layout for empty map.
     map = AffineMap::getMultiDimIdentityMap(shape.size(),
                                             elementType.getContext());
+  } else if (!map.isIdentity() && map.getNumDims() == shape.size()) {
+    // If the layout can be inferred to be an identity, prefer using the
+    // identity layout. Maps whose rank does not match the shape are left
+    // untouched so that verification still reports them, and the computed
+    // strides/offset are only inspected when the map is strided.
+    int64_t offset = 0;
+    SmallVector<int64_t> strides;
+    if (succeeded(::mlir::detail::getAffineMapStridesAndOffset(
+            map, shape, strides, offset)) &&
+        offset == 0 &&
+        getNumContiguousTrailingDimsImpl(shape, strides) ==
+            static_cast<int64_t>(shape.size()))
+      map = AffineMap::getMultiDimIdentityMap(shape.size(),
+                                              elementType.getContext());
+  }
 
   // Wrap AffineMap into Attribute.
   auto layout = AffineMapAttr::get(map);
@@ -716,27 +808,8 @@ int64_t MemRefType::getNumContiguousTrailingDims() {
   SmallVector<int64_t> strides;
   if (!succeeded(getStridesAndOffset(strides, offset)))
     return 0;
 
-  ArrayRef<int64_t> shape = getShape();
-
-  // A memref with dimensions `d0, d1, ..., dn-1` and strides
-  // `s0, s1, ..., sn-1` is contiguous up to dimension `k`
-  // if each stride `si` is the product of the dimensions `di+1, ..., dn-1`,
-  // for `i` in `[k, n-1]`.
-  // Ignore stride elements if the corresponding dimension is 1, as they are
-  // of no consequence.
-  int64_t dimProduct = 1;
-  for (int64_t i = n - 1; i >= 0; --i) {
-    if (shape[i] == 1)
-      continue;
-    if (strides[i] != dimProduct)
-      return n - i - 1;
-    if (shape[i] == ShapedType::kDynamic)
-      return n - i;
-    dimProduct *= shape[i];
-  }
-
-  return n;
+  return getNumContiguousTrailingDimsImpl(getShape(), strides);
 }
 
 MemRefType MemRefType::canonicalizeStridedLayout() {



More information about the Mlir-commits mailing list