[Mlir-commits] [mlir] [mlir][MemRef] Add position-based matching heuristics for rank-reduction with dynamic strides (PR #184334)

Han-Chung Wang llvmlistbot at llvm.org
Tue Mar 3 11:09:11 PST 2026


================
@@ -944,6 +944,56 @@ static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
   return numOccurences;
 }
 
+/// Returns the set of source dimensions that are dropped in a rank reduction.
+/// For each result dimension in order, matches the leftmost unmatched source
+/// dimension with the same size. Source dimensions not matched are dropped.
+///
+/// Example: memref<1x8x1x3> to memref<1x8x3>. Source sizes [1, 8, 1, 3], result
+/// [1, 8, 3]. Match result[0]=1 -> source dim 0, result[1]=8 -> source dim 1,
+/// result[2]=3 -> source dim 3. Source dim 2 is unmatched and dropped.
+static FailureOr<llvm::SmallBitVector>
+computeMemRefRankReductionMaskByPosition(MemRefType originalType,
+                                         MemRefType reducedType,
+                                         ArrayRef<OpFoldResult> sizes) {
+  int64_t rankReduction = originalType.getRank() - reducedType.getRank();
+  if (rankReduction <= 0)
+    return llvm::SmallBitVector(originalType.getRank());
+
+  // Build source sizes from subview sizes (one per source dim).
+  SmallVector<int64_t> sourceSizes(originalType.getRank());
+  for (const auto &it : llvm::enumerate(sizes)) {
+    if (std::optional<int64_t> cst = getConstantIntValue(it.value()))
+      sourceSizes[it.index()] = *cst;
+    else
+      sourceSizes[it.index()] = ShapedType::kDynamic;
+  }
+
+  ArrayRef<int64_t> resultSizes = reducedType.getShape();
+  llvm::SmallBitVector usedSourceDims(originalType.getRank());
+  for (int64_t resultSize : resultSizes) {
+    bool matched = false;
+    for (int64_t j = 0; j < originalType.getRank(); ++j) {
----------------
hanhanW wrote:

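I think we should track `startJ` (the source position just after the previous match) and start the inner loop from there. Otherwise, dimensions can be matched out of order, as in the following example: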

- Subview sizes: [4, 1, 1, 4, 1, 1]
- Result shape: [4, 1, 4, 1]

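`result[1]` matches `subview[1]`, `result[2]` matches `subview[3]`, and then `result[3]` matches `subview[2]`, the leftmost unmatched size-1 dimension. Because `subview[2]` comes before `subview[3]`, the mapping is out of order, which is incorrect.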
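A minimal stand-alone sketch of the two strategies, assuming sizes are plain `int64_t` values and leaving out all MLIR types. `matchResultDims`, `droppedDims`, and the `trackStart` flag are hypothetical names for illustration; this is not the code in the PR. With `trackStart == false` it scans from 0 for the leftmost unused match (the behavior the doc comment describes). With `trackStart == true` it resumes after the previous match, which is the `startJ` idea.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: for each result dim, returns the index of the source
// dim it is matched to, or std::nullopt if some result dim has no match.
// Sizes are plain integers; a dynamic-size sentinel would only compare equal
// to itself here (an assumption, not necessarily what the PR does).
std::optional<std::vector<size_t>>
matchResultDims(const std::vector<int64_t> &sourceSizes,
                const std::vector<int64_t> &resultSizes, bool trackStart) {
  std::vector<bool> used(sourceSizes.size(), false);
  std::vector<size_t> matchedSourceDims;
  size_t startJ = 0;
  for (int64_t resultSize : resultSizes) {
    bool matched = false;
    // trackStart == false: leftmost unused source dim with the same size.
    // trackStart == true:  only source dims after the previous match.
    for (size_t j = trackStart ? startJ : size_t{0}; j < sourceSizes.size();
         ++j) {
      if (used[j] || sourceSizes[j] != resultSize)
        continue;
      used[j] = true;
      matchedSourceDims.push_back(j);
      startJ = j + 1; // The next result dim may only match a later source dim.
      matched = true;
      break;
    }
    if (!matched)
      return std::nullopt;
  }
  return matchedSourceDims;
}

// Hypothetical helper: dropped[j] is true if no result dim matched source dim j.
std::vector<bool> droppedDims(size_t sourceRank,
                              const std::vector<size_t> &matchedSourceDims) {
  std::vector<bool> dropped(sourceRank, true);
  for (size_t j : matchedSourceDims) {
    assert(j < sourceRank && "matched index out of range");
    dropped[j] = false;
  }
  return dropped;
}
```

For the sizes above, `[4, 1, 1, 4, 1, 1]` and `[4, 1, 4, 1]`, the leftmost-unused scan returns `[0, 1, 3, 2]` (out of order). The cursor variant returns `[0, 1, 3, 4]`, dropping source dims 2 and 5. For the doc-comment example, `[1, 8, 1, 3]` to `[1, 8, 3]`, both variants return `[0, 1, 3]`, so the cursor only changes results when an earlier unused dimension would otherwise be picked.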

https://github.com/llvm/llvm-project/pull/184334

