[Mlir-commits] [mlir] [mlir][memref] Add better computeCollapsedLayoutMap support for unit collapse (PR #147967)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 10 06:52:22 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Hocky Yudhiono (hockyy)

<details>
<summary>Changes</summary>

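Addresses an issue similar to #136485, with a more concise fix: when every dimension preceding the last non-unit dimension of a reassociation group has static size 1, the collapsed dimension can reuse that dimension's stride even if its size is dynamic.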

---
Full diff: https://github.com/llvm/llvm-project/pull/147967.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+5-1) 
- (added) mlir/test/Dialect/MemRef/collapse-strided.mlir (+44) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d1a9920aa66c5..ac8451ba0c45c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2422,7 +2422,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
     ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
     while (srcShape[ref.back()] == 1 && ref.size() > 1)
       ref = ref.drop_back();
-    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
+    auto precedingRef = ref.drop_back();
+    bool allUnitPreceding = llvm::all_of(
+        precedingRef, [&srcShape](int idx) { return srcShape[idx] == 1; });
+    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1 ||
+        allUnitPreceding) {
       resultStrides.push_back(srcStrides[ref.back()]);
     } else {
       // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
diff --git a/mlir/test/Dialect/MemRef/collapse-strided.mlir b/mlir/test/Dialect/MemRef/collapse-strided.mlir
new file mode 100644
index 0000000000000..6d82d2316c38e
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/collapse-strided.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: test_collapse(
+func.func @test_collapse(%arg0: memref<20x5xf32>, %arg1: index) {
+  %subview = memref.subview %arg0[0, 0] [1, %arg1] [1, 1] : memref<20x5xf32> to memref<1x?xf32, strided<[5, 1]>>
+  %collapse_shape = memref.collapse_shape %subview [[0, 1]] : memref<1x?xf32, strided<[5, 1]>> into memref<?xf32, strided<[1]>>
+  return
+}
+
+// CHECK-LABEL: test_collapse_5d_middle_dynamic(
+func.func @test_collapse_5d_middle_dynamic(%arg0: memref<8x5x6x9x2xf32>, %arg1: index) {
+  %subview = memref.subview %arg0[0, 0, 0, 0, 0] [1, 5, 1, %arg1, 1] [1, 1, 1, 1, 1]
+    : memref<8x5x6x9x2xf32> to memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>
+  %collapse_shape = memref.collapse_shape %subview [[0, 1, 2, 3, 4]]
+    : memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<?xf32, strided<[?]>>
+  return
+}
+
+// CHECK-LABEL: test_collapse_5d_mostly_units(
+func.func @test_collapse_5d_mostly_units(%arg0: memref<3x4x5x8x2xf32>, %arg1: index) {
+  %subview = memref.subview %arg0[0, 0, 0, 0, 0] [1, 1, 1, %arg1, 1] [1, 1, 1, 1, 1]
+    : memref<3x4x5x8x2xf32> to memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>>
+  %collapse_shape = memref.collapse_shape %subview [[0, 1, 2, 3, 4]]
+    : memref<1x1x1x?x1xf32, strided<[320, 80, 16, 2, 1]>> into memref<?xf32, strided<[2]>>
+  return
+}
+
+// CHECK-LABEL: test_partial_collapse_6d(
+func.func @test_partial_collapse_6d(%arg0: memref<10x8x3x4x5x7xf32>, %arg1: index) {
+  %subview = memref.subview %arg0[0, 0, 0, 0, 0, 0] [1, %arg1, 1, 1, 5, 1] [1, 1, 1, 1, 1, 1]
+    : memref<10x8x3x4x5x7xf32> to memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>>
+  %collapse_shape = memref.collapse_shape %subview [[0, 1, 2, 3], [4, 5]]
+    : memref<1x?x1x1x5x1xf32, strided<[3360, 420, 140, 35, 7, 1]>> into memref<?x5xf32, strided<[420, 7]>>
+  return
+}
+
+// CHECK-LABEL: test_collapse_5d_grouped(
+func.func @test_collapse_5d_grouped(%arg0: memref<8x5x6x9x2xf32>, %arg1: index) {
+  %subview = memref.subview %arg0[0, 0, 0, 0, 0] [1, 5, 1, %arg1, 1] [1, 1, 1, 1, 1]
+    : memref<8x5x6x9x2xf32> to memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>>
+  %collapse_shape = memref.collapse_shape %subview [[0], [1, 2, 3, 4]]
+    : memref<1x5x1x?x1xf32, strided<[540, 108, 18, 2, 1]>> into memref<1x?xf32, strided<[540, ?]>>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/147967


More information about the Mlir-commits mailing list