[Mlir-commits] [mlir] [mlir][MemRef] Add position-based matching heuristics for rank-reduction with dynamic strides (PR #184334)

Abhishek Varma llvmlistbot at llvm.org
Wed Mar 4 22:49:26 PST 2026


https://github.com/Abhishek-Varma updated https://github.com/llvm/llvm-project/pull/184334

>From 8cb79693ba71a580e4abf38feed59d3a4ae4a903 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 3 Mar 2026 08:45:04 +0000
Subject: [PATCH 1/4] [mlir][MemRef] Add position-based matching for
 rank-reduction

When the source type has multiple unit dimensions, stride-based
disambiguation can pick the wrong dimension to drop if the strides are
dynamic. Add position-based matching: for each result dimension in order,
pick the leftmost unmatched source dimension with the same size; unmatched
source dims are dropped.

Example: subview from memref<1x8x1x3> to memref<1x8x3>. Both dim 0 and dim 2
have size 1, and stride-based logic cannot tell them apart when the strides
are dynamic. Position-based matching correctly drops dim 2 (the middle unit
dim) instead of dim 0.

Use position-based matching when more than one dimension is a candidate for
dropping (unusedDims.count() > 1); otherwise, or if position-based matching
fails, fall back to the stride-based logic.

Signed-off-by: Abhishek Varma <abhvarma at amd.com>
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 61 +++++++++++++++++++
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 27 ++++----
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 33 ++++++++--
 3 files changed, 101 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 844e6183cff06..02fe8a178fdd2 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -944,6 +944,56 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
   return numOccurences;
 }
 
+/// Returns the set of source dimensions that are dropped in a rank reduction.
+/// For each result dimension in order, matches the leftmost unmatched source
+/// dimension with the same size. Source dimensions not matched are dropped.
+///
+/// Example: memref<1x8x1x3> to memref<1x8x3>. Source sizes [1, 8, 1, 3], result
+/// [1, 8, 3]. Match result[0]=1 -> source dim 0, result[1]=8 -> source dim 1,
+/// result[2]=3 -> source dim 3. Source dim 2 is unmatched and dropped.
+static FailureOr<llvm::SmallBitVector>
+computeMemRefRankReductionMaskByPosition(MemRefType originalType,
+                                         MemRefType reducedType,
+                                         ArrayRef<OpFoldResult> sizes) {
+  int64_t rankReduction = originalType.getRank() - reducedType.getRank();
+  if (rankReduction <= 0)
+    return llvm::SmallBitVector(originalType.getRank());
+
+  // Build source sizes from subview sizes (one per source dim).
+  SmallVector<int64_t> sourceSizes(originalType.getRank());
+  for (const auto &it : llvm::enumerate(sizes)) {
+    if (std::optional<int64_t> cst = getConstantIntValue(it.value()))
+      sourceSizes[it.index()] = *cst;
+    else
+      sourceSizes[it.index()] = ShapedType::kDynamic;
+  }
+
+  ArrayRef<int64_t> resultSizes = reducedType.getShape();
+  llvm::SmallBitVector usedSourceDims(originalType.getRank());
+  for (int64_t resultSize : resultSizes) {
+    bool matched = false;
+    for (int64_t j = 0; j < originalType.getRank(); ++j) {
+      if (usedSourceDims.test(j))
+        continue;
+      if (sourceSizes[j] == resultSize ||
+          (resultSize == ShapedType::kDynamic &&
+           sourceSizes[j] == ShapedType::kDynamic)) {
+        usedSourceDims.set(j);
+        matched = true;
+        break;
+      }
+    }
+    if (!matched)
+      return failure();
+  }
+
+  llvm::SmallBitVector unusedDims(originalType.getRank());
+  for (int64_t i = 0; i < originalType.getRank(); ++i)
+    if (!usedSourceDims.test(i))
+      unusedDims.set(i);
+  return unusedDims;
+}
+
 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
 /// to be a subset of `originalType` with some `1` entries erased, return the
 /// set of indices that specifies which of the entries of `originalShape` are
@@ -969,6 +1019,17 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
       originalType.getRank())
     return unusedDims;
 
+  // Stride-based logic can be wrong when multiple dims share the same size
+  // (e.g. 1x1x384 -> 1x384) or when strides are dynamic. Try position-based
+  // matching first; it is deterministic and matches subview semantics.
+  if (unusedDims.count() > 1) {
+    FailureOr<llvm::SmallBitVector> positionBased =
+        computeMemRefRankReductionMaskByPosition(originalType, reducedType,
+                                                 sizes);
+    if (succeeded(positionBased))
+      return *positionBased;
+  }
+
   SmallVector<int64_t> originalStrides, candidateStrides;
   int64_t originalOffset, candidateOffset;
   if (failed(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3cfea1e8cd961..a86ae5de5391b 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -47,7 +47,7 @@ func.func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
 //       CHECK: func @subview_of_strides_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, strided{{.*}}>
 //       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4]
-//  CHECK-SAME:                    to memref<1x4xf32, strided<[7, 1], offset: ?>>
+//  CHECK-SAME:                    to memref<1x4xf32, strided<[35, 1], offset: ?>>
 //       CHECK:   %[[M:.+]] = memref.cast %[[S]]
 //  CHECK-SAME:                    to memref<1x4xf32, strided<[?, ?], offset: ?>>
 //       CHECK:   return %[[M]]
@@ -124,16 +124,17 @@ func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 :
 func.func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>,
     %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, strided<[1], offset: ?>>
 {
-  %c1 = arith.constant 1 : index
-  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
-  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, strided<[384, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+  // Use literal 1 (not %c1) so static sizes are [1, 1, kDynamic] -> expected shape [1,1,?].
+  // Dropping dim 1 yields [1,?] with strides [147456, 1].
+  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, 1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[147456, 1], offset: ?>>
+  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<1x?xf32, strided<[147456, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
 //       CHECK: func @multiple_reducing_dims
 //       CHECK:   %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
-//  CHECK-SAME:       : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
+//  CHECK-SAME:       : memref<1x384x384xf32> to memref<1x?xf32, strided<[147456, 1], offset: ?>>
 //       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
-//  CHECK-SAME:       : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+//  CHECK-SAME:       : memref<1x?xf32, strided<[147456, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
 
 // -----
 
@@ -1487,14 +1488,12 @@ func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
 // CHECK-LABEL: func @subview_rank_reduction(
 //  CHECK-SAME:     %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
 func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
-    -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
-  %c1 = arith.constant 1 : index
-  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
-  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
-  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
-      : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
-  // CHECK: return %[[cast]]
-  return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
+    -> memref<1x?xf32, strided<[147456, 1], offset: ?>> {
+  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[147456, 1], offset: ?>>
+  // CHECK: return %[[subview]]
+  %0 = memref.subview %arg0[0, %idx, %idx] [1, 1, %idx] [1, 1, 1]
+      : memref<1x384x384xf32> to memref<1x?xf32, strided<[147456, 1], offset: ?>>
+  return %0 : memref<1x?xf32, strided<[147456, 1], offset: ?>>
 }
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 79156df0ebe1e..3a2401df25867 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -171,10 +171,31 @@ func.func @fold_rank_reducing_subview_with_load
 // CHECK-SAME:   %[[ARG15:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG16:[a-zA-Z0-9_]+]]: index
 //  CHECK-DAG:   %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG14]], %[[ARG9]]]
-//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
-//  CHECK-DAG:   %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
-//      CHECK:   memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[I4]], %[[ARG6]]]
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG14]], %[[ARG8]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG15]], %[[ARG9]]]
+//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG16]], %[[ARG10]]]
+//      CHECK:   memref.load %[[ARG0]][%[[I0]], %[[I1]], %[[I2]], %[[I3]], %[[ARG5]], %[[ARG6]]]
+
+// -----
+
+func.func @fold_rank_reducing_subview_1x8x1x3_to_1x8x3_drop_middle_unit_dim(
+    %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+    %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %0 = memref.subview %arg0[0, 0, 0, 0][1, 8, 1, 3][1, 1, 1, 1]
+      : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+        memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
+  %1 = memref.load %0[%c0, %arg1, %arg2] : memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
+  return %1 : f32
+}
+//      CHECK: func @fold_rank_reducing_subview_1x8x1x3_to_1x8x3_drop_middle_unit_dim
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   memref.load %[[ARG0]][%[[C0]], %[[ARG1]], %[[C0]], %[[ARG2]]]
 
 // -----
 
@@ -611,9 +632,9 @@ func.func @fold_src_fold_dest_nvgpu_device_async_copy(%gmem_memref_3d : memref<2
 //  CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[SRC_IDX_0:.+]]: index, %[[SRC_IDX_1:.+]]: index, %[[SRC_IDX_2:.+]]: index, %[[SRC_SUB_IDX_0:.+]]: index, %[[SRC_SUB_IDX_1:.+]]: index, %[[DEST_IDX_0:.+]]: index, %[[DEST_IDX_1:.+]]: index, %[[DEST_IDX_2:.+]]: index, %[[DEST_IDX_3:.+]]: index, %[[DEST_SUB_IDX_0:.+]]: index, %[[DEST_SUB_IDX_1:.+]]: index)
 //   CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_0]], %[[SRC_SUB_IDX_0]]]
 //   CHECK-DAG: %[[RESOLVED_SRC_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_2]], %[[SRC_SUB_IDX_1]]]
-//   CHECK-DAG: %[[RESOLVED_DST_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_1]], %[[DEST_SUB_IDX_0]]]
+//   CHECK-DAG: %[[RESOLVED_DST_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_0]], %[[DEST_SUB_IDX_0]]]
 //   CHECK-DAG: %[[RESOLVED_DST_IDX_3:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_3]], %[[DEST_SUB_IDX_1]]]
-//   CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[DEST_IDX_0]], %[[RESOLVED_DST_IDX_1]], %[[DEST_IDX_2]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+//   CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[RESOLVED_DST_IDX_0]], %[[DEST_IDX_1]], %[[DEST_IDX_2]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
 
 // -----
 

>From a13042dc5cc89d39e7793bff146df18fad8f42a0 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Tue, 3 Mar 2026 12:48:44 +0000
Subject: [PATCH 2/4] Fix condition + lit tests

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 23 +++++++++--------
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 25 ++++++++++---------
 .../Dialect/MemRef/fold-memref-alias-ops.mlir |  4 +--
 3 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 02fe8a178fdd2..321bfc27516fc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1019,17 +1019,6 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
       originalType.getRank())
     return unusedDims;
 
-  // Stride-based logic can be wrong when multiple dims share the same size
-  // (e.g. 1x1x384 -> 1x384) or when strides are dynamic. Try position-based
-  // matching first; it is deterministic and matches subview semantics.
-  if (unusedDims.count() > 1) {
-    FailureOr<llvm::SmallBitVector> positionBased =
-        computeMemRefRankReductionMaskByPosition(originalType, reducedType,
-                                                 sizes);
-    if (succeeded(positionBased))
-      return *positionBased;
-  }
-
   SmallVector<int64_t> originalStrides, candidateStrides;
   int64_t originalOffset, candidateOffset;
   if (failed(
@@ -1038,6 +1027,18 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
           reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
     return failure();
 
+  // When strides are dynamic and multiple dimensions need to be dropped, we use
+  // position-based matching instead.
+  if (unusedDims.count() > 1 &&
+      (llvm::any_of(originalStrides, ShapedType::isDynamic) ||
+       llvm::any_of(candidateStrides, ShapedType::isDynamic))) {
+    FailureOr<llvm::SmallBitVector> positionBased =
+        computeMemRefRankReductionMaskByPosition(originalType, reducedType,
+                                                 sizes);
+    if (succeeded(positionBased))
+      return *positionBased;
+  }
+
   // For memrefs, a dimension is truly dropped if its corresponding stride is
   // also dropped. This is particularly important when more than one of the dims
   // is 1. Track the number of occurences of the strides in the original type
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a86ae5de5391b..f0193d066dc37 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -124,17 +124,16 @@ func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 :
 func.func @multiple_reducing_dims(%arg0 : memref<1x384x384xf32>,
     %arg1 : index, %arg2 : index, %arg3 : index) -> memref<?xf32, strided<[1], offset: ?>>
 {
-  // Use literal 1 (not %c1) so static sizes are [1, 1, kDynamic] -> expected shape [1,1,?].
-  // Dropping dim 1 yields [1,?] with strides [147456, 1].
-  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, 1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[147456, 1], offset: ?>>
-  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<1x?xf32, strided<[147456, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+  %c1 = arith.constant 1 : index
+  %0 = memref.subview %arg0[0, %arg1, %arg2] [1, %c1, %arg3] [1, 1, 1] : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  %1 = memref.subview %0[0, 0] [1, %arg3] [1, 1] : memref<?x?xf32, strided<[384, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
   return %1 : memref<?xf32, strided<[1], offset: ?>>
 }
 //       CHECK: func @multiple_reducing_dims
 //       CHECK:   %[[REDUCED1:.+]] = memref.subview %{{.+}}[0, %{{.+}}, %{{.+}}] [1, 1, %{{.+}}] [1, 1, 1]
-//  CHECK-SAME:       : memref<1x384x384xf32> to memref<1x?xf32, strided<[147456, 1], offset: ?>>
+//  CHECK-SAME:       : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
 //       CHECK:   %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
-//  CHECK-SAME:       : memref<1x?xf32, strided<[147456, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+//  CHECK-SAME:       : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
 
 // -----
 
@@ -1488,12 +1487,14 @@ func.func private @ub_negative_alloc_size() -> memref<?x?x?xi1> {
 // CHECK-LABEL: func @subview_rank_reduction(
 //  CHECK-SAME:     %[[arg0:.*]]: memref<1x384x384xf32>, %[[arg1:.*]]: index
 func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
-    -> memref<1x?xf32, strided<[147456, 1], offset: ?>> {
-  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[147456, 1], offset: ?>>
-  // CHECK: return %[[subview]]
-  %0 = memref.subview %arg0[0, %idx, %idx] [1, 1, %idx] [1, 1, 1]
-      : memref<1x384x384xf32> to memref<1x?xf32, strided<[147456, 1], offset: ?>>
-  return %0 : memref<1x?xf32, strided<[147456, 1], offset: ?>>
+    -> memref<?x?xf32, strided<[384, 1], offset: ?>> {
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[subview:.*]] = memref.subview %[[arg0]][0, %[[arg1]], %[[arg1]]] [1, 1, %[[arg1]]] [1, 1, 1] : memref<1x384x384xf32> to memref<1x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: %[[cast:.*]] = memref.cast %[[subview]] : memref<1x?xf32, strided<[384, 1], offset: ?>> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  %0 = memref.subview %arg0[0, %idx, %idx] [1, %c1, %idx] [1, 1, 1]
+      : memref<1x384x384xf32> to memref<?x?xf32, strided<[384, 1], offset: ?>>
+  // CHECK: return %[[cast]]
+  return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
 }
 
 // -----
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3a2401df25867..9f3e8deb48a4c 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -632,9 +632,9 @@ func.func @fold_src_fold_dest_nvgpu_device_async_copy(%gmem_memref_3d : memref<2
 //  CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[SRC_IDX_0:.+]]: index, %[[SRC_IDX_1:.+]]: index, %[[SRC_IDX_2:.+]]: index, %[[SRC_SUB_IDX_0:.+]]: index, %[[SRC_SUB_IDX_1:.+]]: index, %[[DEST_IDX_0:.+]]: index, %[[DEST_IDX_1:.+]]: index, %[[DEST_IDX_2:.+]]: index, %[[DEST_IDX_3:.+]]: index, %[[DEST_SUB_IDX_0:.+]]: index, %[[DEST_SUB_IDX_1:.+]]: index)
 //   CHECK-DAG: %[[RESOLVED_SRC_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_0]], %[[SRC_SUB_IDX_0]]]
 //   CHECK-DAG: %[[RESOLVED_SRC_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[SRC_IDX_2]], %[[SRC_SUB_IDX_1]]]
-//   CHECK-DAG: %[[RESOLVED_DST_IDX_0:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_0]], %[[DEST_SUB_IDX_0]]]
+//   CHECK-DAG: %[[RESOLVED_DST_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_1]], %[[DEST_SUB_IDX_0]]]
 //   CHECK-DAG: %[[RESOLVED_DST_IDX_3:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_3]], %[[DEST_SUB_IDX_1]]]
-//   CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[RESOLVED_DST_IDX_0]], %[[DEST_IDX_1]], %[[DEST_IDX_2]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+//   CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[DEST_IDX_0]], %[[RESOLVED_DST_IDX_1]], %[[DEST_IDX_2]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
 
 // -----
 

>From 32ed9292c0432250fd510e2a7e4b7b169a7058b4 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Wed, 4 Mar 2026 06:01:32 +0000
Subject: [PATCH 3/4] Review comment v1.0

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 123 +++++++++---------
 .../Dialect/MemRef/fold-memref-alias-ops.mlir |   6 +-
 2 files changed, 67 insertions(+), 62 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 321bfc27516fc..ffff41f2ec6b5 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -970,16 +970,14 @@ computeMemRefRankReductionMaskByPosition(MemRefType originalType,
 
   ArrayRef<int64_t> resultSizes = reducedType.getShape();
   llvm::SmallBitVector usedSourceDims(originalType.getRank());
+  int64_t startJ = 0;
   for (int64_t resultSize : resultSizes) {
     bool matched = false;
-    for (int64_t j = 0; j < originalType.getRank(); ++j) {
-      if (usedSourceDims.test(j))
-        continue;
-      if (sourceSizes[j] == resultSize ||
-          (resultSize == ShapedType::kDynamic &&
-           sourceSizes[j] == ShapedType::kDynamic)) {
+    for (int64_t j = startJ; j < originalType.getRank(); ++j) {
+      if (sourceSizes[j] == resultSize) {
         usedSourceDims.set(j);
         matched = true;
+        startJ = j + 1;
         break;
       }
     }
@@ -994,6 +992,45 @@ computeMemRefRankReductionMaskByPosition(MemRefType originalType,
   return unusedDims;
 }
 
+/// Returns the set of source dimensions that are dropped in a rank reduction.
+/// A dimension is dropped if its stride is dropped; uses stride occurrence
+/// counting to disambiguate when multiple unit dims exist.
+///
+/// Example: memref<1x1x?xf32, strided<[?, 4, 1]>> to memref<1x4xf32,
+/// strided<[4, 1]>>. Source strides [?, 4, 1], candidate [4, 1]. Dim 0 (stride
+/// ?) can be dropped; dim 1 (stride 4) must be kept. Source dim 0 is dropped.
+static FailureOr<llvm::SmallBitVector> computeMemRefRankReductionMaskByStrides(
+    MemRefType originalType, MemRefType reducedType,
+    ArrayRef<int64_t> originalStrides, ArrayRef<int64_t> candidateStrides,
+    llvm::SmallBitVector unusedDims) {
+  std::map<int64_t, unsigned> currUnaccountedStrides =
+      getNumOccurences(originalStrides);
+  std::map<int64_t, unsigned> candidateStridesNumOccurences =
+      getNumOccurences(candidateStrides);
+  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
+    if (!unusedDims.test(dim))
+      continue;
+    int64_t originalStride = originalStrides[dim];
+    if (currUnaccountedStrides[originalStride] >
+        candidateStridesNumOccurences[originalStride]) {
+      currUnaccountedStrides[originalStride]--;
+      continue;
+    }
+    if (currUnaccountedStrides[originalStride] ==
+        candidateStridesNumOccurences[originalStride]) {
+      unusedDims.reset(dim);
+      continue;
+    }
+    if (currUnaccountedStrides[originalStride] <
+        candidateStridesNumOccurences[originalStride])
+      return failure();
+  }
+  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
+      originalType.getRank())
+    return failure();
+  return unusedDims;
+}
+
 /// Given the `originalType` and a `candidateReducedType` whose shape is assumed
 /// to be a subset of `originalType` with some `1` entries erased, return the
 /// set of indices that specifies which of the entries of `originalShape` are
@@ -1027,59 +1064,27 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
           reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
     return failure();
 
-  // When strides are dynamic and multiple dimensions need to be dropped, we use
-  // position-based matching instead.
-  if (unusedDims.count() > 1 &&
-      (llvm::any_of(originalStrides, ShapedType::isDynamic) ||
-       llvm::any_of(candidateStrides, ShapedType::isDynamic))) {
-    FailureOr<llvm::SmallBitVector> positionBased =
-        computeMemRefRankReductionMaskByPosition(originalType, reducedType,
-                                                 sizes);
-    if (succeeded(positionBased))
-      return *positionBased;
-  }
-
-  // For memrefs, a dimension is truly dropped if its corresponding stride is
-  // also dropped. This is particularly important when more than one of the dims
-  // is 1. Track the number of occurences of the strides in the original type
-  // and the candidate type. For each unused dim that stride should not be
-  // present in the candidate type. Note that there could be multiple dimensions
-  // that have the same size. We dont need to exactly figure out which dim
-  // corresponds to which stride, we just need to verify that the number of
-  // reptitions of a stride in the original + number of unused dims with that
-  // stride == number of repititions of a stride in the candidate.
-  std::map<int64_t, unsigned> currUnaccountedStrides =
-      getNumOccurences(originalStrides);
-  std::map<int64_t, unsigned> candidateStridesNumOccurences =
-      getNumOccurences(candidateStrides);
-  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
-    if (!unusedDims.test(dim))
-      continue;
-    int64_t originalStride = originalStrides[dim];
-    if (currUnaccountedStrides[originalStride] >
-        candidateStridesNumOccurences[originalStride]) {
-      // This dim can be treated as dropped.
-      currUnaccountedStrides[originalStride]--;
-      continue;
-    }
-    if (currUnaccountedStrides[originalStride] ==
-        candidateStridesNumOccurences[originalStride]) {
-      // The stride for this is not dropped. Keep as is.
-      unusedDims.reset(dim);
-      continue;
-    }
-    if (currUnaccountedStrides[originalStride] <
-        candidateStridesNumOccurences[originalStride]) {
-      // This should never happen. Cant have a stride in the reduced rank type
-      // that wasnt in the original one.
-      return failure();
-    }
-  }
-
-  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
-      originalType.getRank())
-    return failure();
-  return unusedDims;
+  // Prefer stride-based matching when either type has a static non-innermost
+  // stride; fall back to position-based matching otherwise or on failure.
+  auto hasNonTrivialStaticStride = [](ArrayRef<int64_t> strides) {
+    // The innermost stride 1 is trivial for row-major and does not help
+    // disambiguate.
+    if (strides.size() <= 1)
+      return false;
+    return llvm::any_of(strides.drop_back(),
+                        [](int64_t s) { return !ShapedType::isDynamic(s); });
+  };
+  if (hasNonTrivialStaticStride(originalStrides) ||
+      hasNonTrivialStaticStride(candidateStrides)) {
+    FailureOr<llvm::SmallBitVector> strideBased =
+        computeMemRefRankReductionMaskByStrides(originalType, reducedType,
+                                                originalStrides,
+                                                candidateStrides, unusedDims);
+    if (succeeded(strideBased))
+      return *strideBased;
+  }
+  return computeMemRefRankReductionMaskByPosition(originalType, reducedType,
+                                                  sizes);
 }
 
 llvm::SmallBitVector SubViewOp::getDroppedDims() {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 9f3e8deb48a4c..3f77a0553fff9 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -172,9 +172,9 @@ func.func @fold_rank_reducing_subview_with_load
 // CHECK-SAME:   %[[ARG16:[a-zA-Z0-9_]+]]: index
 //  CHECK-DAG:   %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]]
 //  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG14]], %[[ARG8]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG15]], %[[ARG9]]]
-//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG16]], %[[ARG10]]]
-//      CHECK:   memref.load %[[ARG0]][%[[I0]], %[[I1]], %[[I2]], %[[I3]], %[[ARG5]], %[[ARG6]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
+//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
+//      CHECK:   memref.load %[[ARG0]][%[[I0]], %[[I1]], %[[ARG3]], %[[I2]], %[[I3]], %[[ARG6]]]
 
 // -----
 

>From cb0c647f6cc3db19197bdfc9780aa96b26c95b93 Mon Sep 17 00:00:00 2001
From: Abhishek Varma <abhvarma at amd.com>
Date: Thu, 5 Mar 2026 06:46:52 +0000
Subject: [PATCH 4/4] Restore comments

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ffff41f2ec6b5..d36b72d5652c9 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1003,6 +1003,13 @@ static FailureOr<llvm::SmallBitVector> computeMemRefRankReductionMaskByStrides(
     MemRefType originalType, MemRefType reducedType,
     ArrayRef<int64_t> originalStrides, ArrayRef<int64_t> candidateStrides,
     llvm::SmallBitVector unusedDims) {
+  // Track the number of occurrences of each stride in the original type and
+  // in the candidate type. Each dropped dim removes one occurrence of its
+  // stride from the candidate type. Note that there could be multiple
+  // dimensions that have the same size. We don't need to figure out exactly
+  // which dim corresponds to which stride; we just need to verify that the
+  // number of repetitions of a stride in the original == the number of
+  // repetitions in the candidate + the number of dropped dims with it.
   std::map<int64_t, unsigned> currUnaccountedStrides =
       getNumOccurences(originalStrides);
   std::map<int64_t, unsigned> candidateStridesNumOccurences =
@@ -1013,17 +1020,22 @@ static FailureOr<llvm::SmallBitVector> computeMemRefRankReductionMaskByStrides(
     int64_t originalStride = originalStrides[dim];
     if (currUnaccountedStrides[originalStride] >
         candidateStridesNumOccurences[originalStride]) {
+      // This dim can be treated as dropped.
       currUnaccountedStrides[originalStride]--;
       continue;
     }
     if (currUnaccountedStrides[originalStride] ==
         candidateStridesNumOccurences[originalStride]) {
+      // The stride for this is not dropped. Keep as is.
       unusedDims.reset(dim);
       continue;
     }
     if (currUnaccountedStrides[originalStride] <
-        candidateStridesNumOccurences[originalStride])
+        candidateStridesNumOccurences[originalStride]) {
+      // This should never happen. Can't have a stride in the reduced-rank type
+      // that wasn't in the original one.
       return failure();
+    }
   }
   if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
       originalType.getRank())



More information about the Mlir-commits mailing list