[Mlir-commits] [mlir] [mlir][MemRef] Add position-based matching heuristics for rank-reduction with dynamic strides (PR #184334)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 3 05:00:46 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Abhishek Varma (Abhishek-Varma)
<details>
<summary>Changes</summary>
When the source of a rank-reducing subview has multiple unit dimensions, stride-based disambiguation can pick the wrong one to drop if the strides are dynamic. Add position-based matching: for each result dimension in order, pick the leftmost unmatched source dimension with the same size; unmatched source dims are dropped.
Example: subview from memref<1x8x1x3> to memref<1x8x3>. Both dim 0 and dim 2 have size 1, and stride-based logic cannot tell them apart when the strides are dynamic. Position-based matching correctly drops dim 2 (the middle unit dim) instead of dim 0.
Use position-based matching when more than one dimension is a candidate for dropping and the source or result strides are dynamic. Otherwise, or if position-based matching fails, fall back to the stride-based logic.
Input:
```
func.func @fold_rank_reducing_subview_1x8x1x3_to_1x8x3_drop_middle_unit_dim(
%arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
%c0 = arith.constant 0 : index
%0 = memref.subview %arg0[0, 0, 0, 0][1, 8, 1, 3][1, 1, 1, 1]
: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
%1 = memref.load %0[%c0, %arg1, %arg2] : memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
return %1 : f32
}
```
Without this patch, the load is folded to:
```
memref.load %arg0[%c0, %c0, %arg1, %arg2]
```
With this patch, the load is folded to:
```
memref.load %arg0[%c0, %arg1, %c0, %arg2]
```
Signed-off-by: Abhishek Varma <abhvarma@<!-- -->amd.com>
---
Full diff: https://github.com/llvm/llvm-project/pull/184334.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+62)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+1-1)
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+25-4)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 844e6183cff06..321bfc27516fc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -944,6 +944,56 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
return numOccurences;
}
+/// Returns the set of source dimensions that are dropped in a rank reduction,
+/// determined purely from the subview sizes and the result shape (strides are
+/// not consulted).
+///
+/// Result dimensions are matched against source dimensions as an
+/// order-preserving subsequence: each result dimension is assigned to the
+/// leftmost source dimension with the same size that comes *after* the source
+/// dimension matched by the previous result dimension. Source dimensions that
+/// are not matched are dropped.
+///
+/// Example: memref<1x8x1x3> to memref<1x8x3>. Source sizes [1, 8, 1, 3], result
+/// [1, 8, 3]. Match result[0]=1 -> source dim 0, result[1]=8 -> source dim 1,
+/// result[2]=3 -> source dim 3. Source dim 2 is unmatched and dropped.
+///
+/// The matching must be monotonic. For memref<1x8x1x3> to memref<8x1x3>,
+/// result[1]=1 must map to source dim 2 (after source dim 1), so source dim 0
+/// is dropped. Reusing the earlier, still-unmatched source dim 0 would drop
+/// dim 2 instead and pair result dim 0 (size 8) with a unit source dim.
+///
+/// Returns failure if `sizes` does not have one entry per source dimension, if
+/// a result dimension cannot be matched, or if a dimension that would be
+/// dropped is not a static unit dimension.
+static FailureOr<llvm::SmallBitVector>
+computeMemRefRankReductionMaskByPosition(MemRefType originalType,
+                                         MemRefType reducedType,
+                                         ArrayRef<OpFoldResult> sizes) {
+  int64_t sourceRank = originalType.getRank();
+  if (sourceRank <= reducedType.getRank())
+    return llvm::SmallBitVector(sourceRank);
+
+  // One subview size is expected per source dimension; anything else cannot be
+  // lined up with the source shape.
+  if (static_cast<int64_t>(sizes.size()) != sourceRank)
+    return failure();
+
+  // Static size of each source dimension as selected by the subview, or
+  // ShapedType::kDynamic when the size is not a constant.
+  SmallVector<int64_t> sourceSizes;
+  sourceSizes.reserve(sourceRank);
+  for (OpFoldResult size : sizes)
+    sourceSizes.push_back(
+        getConstantIntValue(size).value_or(ShapedType::kDynamic));
+
+  // Start with every source dim marked as dropped and clear the ones matched
+  // by a result dim. A dynamic result size only matches a dynamic source size,
+  // since kDynamic compares equal to itself.
+  llvm::SmallBitVector droppedDims(sourceRank, /*t=*/true);
+  int64_t srcDim = 0;
+  for (int64_t resultSize : reducedType.getShape()) {
+    while (srcDim < sourceRank && sourceSizes[srcDim] != resultSize)
+      ++srcDim;
+    if (srcDim == sourceRank)
+      return failure();
+    droppedDims.reset(srcDim);
+    // The next result dim may only match a later source dim.
+    ++srcDim;
+  }
+
+  // Only static unit dimensions may be dropped by a rank reduction; otherwise
+  // the shapes do not line up and the caller should use another strategy.
+  for (unsigned dim : droppedDims.set_bits())
+    if (sourceSizes[dim] != 1)
+      return failure();
+
+  return droppedDims;
+}
+
/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
/// to be a subset of `originalType` with some `1` entries erased, return the
/// set of indices that specifies which of the entries of `originalShape` are
@@ -977,6 +1027,18 @@ computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
return failure();
+ // When strides are dynamic and multiple dimensions need to be dropped, we use
+ // position-based matching instead.
+ if (unusedDims.count() > 1 &&
+ (llvm::any_of(originalStrides, ShapedType::isDynamic) ||
+ llvm::any_of(candidateStrides, ShapedType::isDynamic))) {
+ FailureOr<llvm::SmallBitVector> positionBased =
+ computeMemRefRankReductionMaskByPosition(originalType, reducedType,
+ sizes);
+ if (succeeded(positionBased))
+ return *positionBased;
+ }
+
// For memrefs, a dimension is truly dropped if its corresponding stride is
// also dropped. This is particularly important when more than one of the dims
// is 1. Track the number of occurences of the strides in the original type
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 3cfea1e8cd961..f0193d066dc37 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -47,7 +47,7 @@ func.func @subview_of_size_memcast(%arg : memref<4x6x16x32xi8>) ->
// CHECK: func @subview_of_strides_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<1x1x?xf32, strided{{.*}}>
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][0, 0, 0] [1, 1, 4]
-// CHECK-SAME: to memref<1x4xf32, strided<[7, 1], offset: ?>>
+// CHECK-SAME: to memref<1x4xf32, strided<[35, 1], offset: ?>>
// CHECK: %[[M:.+]] = memref.cast %[[S]]
// CHECK-SAME: to memref<1x4xf32, strided<[?, ?], offset: ?>>
// CHECK: return %[[M]]
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 79156df0ebe1e..9f3e8deb48a4c 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -171,10 +171,31 @@ func.func @fold_rank_reducing_subview_with_load
// CHECK-SAME: %[[ARG15:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG16:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[I0:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG13]], %[[ARG7]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG14]], %[[ARG9]]]
-// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG15]], %[[ARG10]]]
-// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG16]], %[[ARG11]]]
-// CHECK: memref.load %[[ARG0]][%[[I0]], %[[ARG2]], %[[I2]], %[[I3]], %[[I4]], %[[ARG6]]]
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG14]], %[[ARG8]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG15]], %[[ARG9]]]
+// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG4]], %[[ARG16]], %[[ARG10]]]
+// CHECK: memref.load %[[ARG0]][%[[I0]], %[[I1]], %[[I2]], %[[I3]], %[[ARG5]], %[[ARG6]]]
+
+// -----
+
+// Rank-reducing subview of a memref with fully dynamic strides. Source dims 0
+// and 2 both have size 1, so the strides cannot tell them apart, but the result
+// shape [1, 8, 3] can only be obtained by dropping source dim 2. After folding,
+// the load on %arg0 must therefore use %c0 for both source dims 0 and 2.
+func.func @fold_rank_reducing_subview_1x8x1x3_to_1x8x3_drop_middle_unit_dim(
+ %arg0 : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>,
+ %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+ %c0 = arith.constant 0 : index
+ %0 = memref.subview %arg0[0, 0, 0, 0][1, 8, 1, 3][1, 1, 1, 1]
+ : memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>> to
+ memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
+ %1 = memref.load %0[%c0, %arg1, %arg2] : memref<1x8x3xf32, strided<[?, ?, ?], offset: ?>>
+ return %1 : f32
+}
+// CHECK: func @fold_rank_reducing_subview_1x8x1x3_to_1x8x3_drop_middle_unit_dim
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?xf32, strided<[?, ?, ?, ?], offset: ?>>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: memref.load %[[ARG0]][%[[C0]], %[[ARG1]], %[[C0]], %[[ARG2]]]
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/184334
More information about the Mlir-commits
mailing list