[Mlir-commits] [mlir] 50826f9 - [mlir][MemRef] Add position-based matching heuristics for rank-reduction with dynamic strides (#184334)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 4 23:00:03 PST 2026


Author: Abhishek Varma
Date: 2026-03-05T06:59:58Z
New Revision: 50826f9c3b9c31bcc77846ba00ef106d0e0abc83

URL: https://github.com/llvm/llvm-project/commit/50826f9c3b9c31bcc77846ba00ef106d0e0abc83
DIFF: https://github.com/llvm/llvm-project/commit/50826f9c3b9c31bcc77846ba00ef106d0e0abc83.diff

LOG: [mlir][MemRef] Add position-based matching heuristics for rank-reduction with dynamic strides (#184334)

When the source has multiple unit dimensions, stride-based disambiguation
can pick the wrong one to drop when the strides are dynamic. Add
position-based matching: for each result dimension in order, pick the
leftmost source dimension after the previous match that has the same size;
unmatched source dims are dropped.

Example: subview from memref<1x8x1x3> to memref<1x8x3>. Both dim 0 and
dim 2 have size 1, and stride-based logic cannot tell them apart when the
strides are dynamic. Dropping dim 0 would be wrong, since the remaining
sizes would be 8x1x3 rather than 1x8x3; position-based matching correctly
drops dim 2 (the middle unit dim) instead.

When non-trivial static strides are available, the stride-based logic is
tried first; otherwise, or if it fails, we fall back to the position-based
logic introduced by this patch.

Input:
```
func.func @fold_rank_reducing_subview_1x8x1x3_to_1x8x3_drop_middle_unit_dim(
    %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
    %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
  %c0 = arith.constant 0 : index
  %0 = memref.subview %arg0[0, 0, 0, 0][1, 8, 1, 3][1, 1, 1, 1]
      : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
        memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
  %1 = memref.load %0[%c0, %arg1, %arg2] : memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
  return %1 : f32
}
```

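Without this patch, the load is folded into the following, which indexes
source dim 2 (size 1) with %arg1:
```
memref.load %arg0[%c0, %c0, %arg1, %arg2]
```

With this patch, the load is folded into:
```
memref.load %arg0[%c0, %arg1, %c0, %arg2]
```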

Signed-off-by: Abhishek Varma <abhvarma at amd.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 844e6183cff06..d36b72d5652c9 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -944,42 +944,66 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
   return numOccurences;
 }
 
-/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
-/// to be a subset of `originalType` with some `1` entries erased, return the
-/// set of indices that specifies which of the entries of `originalShape` are
-/// dropped to obtain `reducedShape`.
-/// This accounts for cases where there are multiple unit-dims, but only a
-/// subset of those are dropped. For MemRefTypes these can be disambiguated
-/// using the strides. If a dimension is dropped the stride must be dropped too.
+/// Returns the set of source dimensions that are dropped in a rank reduction.
+/// Each result dim, in order, matches the leftmost source dim past the previous
+/// match with equal static size (SSA sizes are dynamic); the rest are dropped.
+///
+/// Example: memref<1x8x1x3> to memref<1x8x3>. Source sizes [1, 8, 1, 3], result
+/// [1, 8, 3]. Match result[0]=1 -> source dim 0, result[1]=8 -> source dim 1,
+/// result[2]=3 -> source dim 3. Source dim 2 is unmatched and dropped.
 static FailureOr<llvm::SmallBitVector>
-computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
-                               ArrayRef<OpFoldResult> sizes) {
-  llvm::SmallBitVector unusedDims(originalType.getRank());
-  if (originalType.getRank() == reducedType.getRank())
-    return unusedDims;
-
-  for (const auto &dim : llvm::enumerate(sizes))
-    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
-      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
-        unusedDims.set(dim.index());
-
-  // Early exit for the case where the number of unused dims matches the number
-  // of ranks reduced.
-  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
-      originalType.getRank())
-    return unusedDims;
+computeMemRefRankReductionMaskByPosition(MemRefType originalType,
+                                         MemRefType reducedType,
+                                         ArrayRef<OpFoldResult> sizes) {
+  int64_t rankReduction = originalType.getRank() - reducedType.getRank();
+  if (rankReduction <= 0)
+    return llvm::SmallBitVector(originalType.getRank());
+
+  // Use only static (attribute) sizes: SSA-value sizes are `?` in the result
+  // type even when they are constants, so they must stay dynamic here.
+  SmallVector<int64_t> sourceSizes(originalType.getRank(),
+                                   ShapedType::kDynamic);
+  for (const auto &it : llvm::enumerate(sizes)) {
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(it.value()))
+      sourceSizes[it.index()] = llvm::cast<IntegerAttr>(attr).getInt();
+  }
+
+  ArrayRef<int64_t> resultSizes = reducedType.getShape();
+  llvm::SmallBitVector usedSourceDims(originalType.getRank());
+  int64_t startJ = 0;
+  for (int64_t resultSize : resultSizes) {
+    bool matched = false;
+    for (int64_t j = startJ; j < originalType.getRank(); ++j) {
+      if (sourceSizes[j] == resultSize) {
+        usedSourceDims.set(j);
+        matched = true;
+        startJ = j + 1;
+        break;
+      }
+    }
+    if (!matched)
+      return failure();
+  }
 
-  SmallVector<int64_t> originalStrides, candidateStrides;
-  int64_t originalOffset, candidateOffset;
-  if (failed(
-          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
-      failed(
-          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
-    return failure();
+  llvm::SmallBitVector unusedDims(originalType.getRank());
+  for (int64_t i = 0; i < originalType.getRank(); ++i)
+    if (!usedSourceDims.test(i))
+      unusedDims.set(i);
+  return unusedDims;
+}
 
-  // For memrefs, a dimension is truly dropped if its corresponding stride is
-  // also dropped. This is particularly important when more than one of the dims
-  // is 1. Track the number of occurences of the strides in the original type
+/// Returns the set of source dimensions that are dropped in a rank reduction.
+/// A dimension is dropped if its stride is dropped; uses stride occurrence
+/// counting to disambiguate when multiple unit dims exist.
+///
+/// Example: memref<1x1x?xf32, strided<[?, 4, 1]>> to memref<1x4xf32,
+/// strided<[4, 1]>>. Source strides [?, 4, 1], candidate [4, 1]. Dim 0 (stride
+/// ?) can be dropped; dim 1 (stride 4) must be kept. Source dim 0 is dropped.
+static FailureOr<llvm::SmallBitVector> computeMemRefRankReductionMaskByStrides(
+    MemRefType originalType, MemRefType reducedType,
+    ArrayRef<int64_t> originalStrides, ArrayRef<int64_t> candidateStrides,
+    llvm::SmallBitVector unusedDims) {
+  // Track the number of occurrences of the strides in the original type
   // and the candidate type. For each unused dim that stride should not be
   // present in the candidate type. Note that there could be multiple dimensions
   // that have the same size. We dont need to exactly figure out which dim
@@ -1013,13 +1037,68 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
       return failure();
     }
   }
-
-  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
+  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
       originalType.getRank())
     return failure();
   return unusedDims;
 }
 
+/// Given `originalType` and a `reducedType` whose shape is assumed to be a
+/// subset of `originalType` with some `1` entries erased, return the set of
+/// `originalType` dimensions that are dropped to obtain `reducedType`.
+/// When only a subset of the unit dims is dropped and either type has a static
+/// non-innermost stride, strides are used to disambiguate (a dropped dim's
+/// stride must be dropped too); otherwise, or if that fails, dims are matched
+/// by position. Fails if either layout is not strided or no match is found.
+static FailureOr<llvm::SmallBitVector>
+computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
+                               ArrayRef<OpFoldResult> sizes) {
+  llvm::SmallBitVector unusedDims(originalType.getRank());
+  if (originalType.getRank() == reducedType.getRank())
+    return unusedDims;
+
+  for (const auto &dim : llvm::enumerate(sizes))
+    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
+      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
+        unusedDims.set(dim.index());
+
+  // Early exit for the case where the number of unused dims matches the number
+  // of ranks reduced.
+  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
+      originalType.getRank())
+    return unusedDims;
+
+  SmallVector<int64_t> originalStrides, candidateStrides;
+  int64_t originalOffset, candidateOffset;
+  if (failed(
+          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
+      failed(
+          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
+    return failure();
+
+  // Use stride-based matching when a static non-innermost stride exists (it
+  // preserves static strides); otherwise, or if it fails, match by position.
+  auto hasNonTrivialStaticStride = [](ArrayRef<int64_t> strides) {
+    // The innermost stride is ignored: it is usually 1 for row-major layouts
+    // and does not help disambiguate.
+    if (strides.size() <= 1)
+      return false;
+    return llvm::any_of(strides.drop_back(),
+                        [](int64_t s) { return !ShapedType::isDynamic(s); });
+  };
+  if (hasNonTrivialStaticStride(originalStrides) ||
+      hasNonTrivialStaticStride(candidateStrides)) {
+    FailureOr<llvm::SmallBitVector> strideBased =
+        computeMemRefRankReductionMaskByStrides(originalType, reducedType,
+                                                originalStrides,
+                                                candidateStrides, unusedDims);
+    if (succeeded(strideBased))
+      return *strideBased;
+  }
+  return computeMemRefRankReductionMaskByPosition(originalType, reducedType,
+                                                  sizes);
+}
+
 llvm::SmallBitVector SubViewOp::getDroppedDims() {
   MemRefType sourceType = getSourceType();
   MemRefType resultType = getType();

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3cfea1e8cd961..f0193d066dc37 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -47,7 +47,7 @@ func.func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
 //       CHECK: func @subview_of_strides_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, strided{{.*}}>
 //       CHECK:   %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4]
-//  CHECK-SAME:                    to memref<1x4xf32, strided<[7, 1], offset: ?>>
+//  CHECK-SAME:                    to memref<1x4xf32, strided<[35, 1], offset: ?>>
 //       CHECK:   %[[M:.+]] = memref.cast %[[S]]
 //  CHECK-SAME:                    to memref<1x4xf32, strided<[?, ?], offset: ?>>
 //       CHECK:   return %[[M]]

diff  --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 79156df0ebe1e..3f77a0553fff9 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -171,10 +171,31 @@ func.func @fold_rank_reducing_subview_with_load
 // CHECK-SAME:   %[[ARG15:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:   %[[ARG16:[a-zA-Z0-9_]+]]: index
 //  CHECK-DAG:   %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]]
-//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG14]], %[[ARG9]]]
-//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
-//  CHECK-DAG:   %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
-//      CHECK:   memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[I4]], %[[ARG6]]]
+//  CHECK-DAG:   %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG14]], %[[ARG8]]]
+//  CHECK-DAG:   %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
+//  CHECK-DAG:   %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
+//      CHECK:   memref.load %[[ARG0]][%[[I0]], %[[I1]], %[[ARG3]], %[[I2]], %[[I3]], %[[ARG6]]]
+
+// -----
+
+func.func @fold_rank_reducing_subview_1x8x1x3_to_1x8x3_drop_middle_unit_dim(
+    %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+    %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %0 = memref.subview %arg0[0, 0, 0, 0][1, 8, 1, 3][1, 1, 1, 1]
+      : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+        memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
+  %1 = memref.load %0[%c0, %arg1, %arg2] : memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
+  return %1 : f32
+}
+//      CHECK: func @fold_rank_reducing_subview_1x8x1x3_to_1x8x3_drop_middle_unit_dim
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//      CHECK:   memref.load %[[ARG0]][%[[C0]], %[[ARG1]], %[[C0]], %[[ARG2]]]
 
 // -----
 


        


More information about the Mlir-commits mailing list