[Mlir-commits] [mlir] [mlir][memref] Fix computeCollapsedLayoutMap for contiguous dynamic dim (PR #136485)

Maya Amrami llvmlistbot at llvm.org
Mon Jun 23 08:02:14 PDT 2025


https://github.com/amrami updated https://github.com/llvm/llvm-project/pull/136485

>From 7edff84214f6f75487dccc2d64c915338dd6b1ed Mon Sep 17 00:00:00 2001
From: Maya Amrami <mayaam88 at gmail.com>
Date: Sun, 20 Apr 2025 14:43:17 +0300
Subject: [PATCH] [mlir][memref] Fix computeCollapsedLayoutMap for contiguous
 dynamic dim

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 23 ++++++++++++++++++-----
 mlir/test/Dialect/MemRef/invalid.mlir    |  8 ++++++++
 mlir/test/Dialect/MemRef/ops.mlir        | 21 ++++++++++++++++++++-
 3 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d56b32193765e..28770bffc45ad 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::memref;
@@ -2413,11 +2414,23 @@ computeCollapsedLayoutMap(MemRefType srcType,
     if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
       resultStrides.push_back(srcStrides[ref.back()]);
     } else {
-      // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
-      // the corresponding stride may have to be skipped. (See above comment.)
-      // Therefore, the result stride cannot be statically determined and must
-      // be dynamic.
-      resultStrides.push_back(ShapedType::kDynamic);
+      // Here the group has several dims and its last one is dynamic. If all
+      // the other dims are of size 1, the dynamic size is preserved; if that
+      // dim is also contiguous in the source (stride 1), the result stride
+      // is statically known to be 1.
+      bool contiguousSrcDim = srcStrides[ref.back()] == 1;
+      bool dynamicSizeIsPreserved =
+          std::all_of(ref.begin(), ref.end() - 1,
+                      [srcShape](int64_t dim) { return srcShape[dim] == 1; });
+      if (contiguousSrcDim && dynamicSizeIsPreserved)
+        resultStrides.push_back(1);
+      else {
+        // Dynamically-sized dims may turn out to be dims of size 1 at runtime,
+        // so the corresponding stride may have to be skipped. (See above
+        // comment.) Therefore, the result stride cannot be statically
+        // determined and must be dynamic.
+        resultStrides.push_back(ShapedType::kDynamic);
+      }
     }
   }
 
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f908efb638446..8c9d744a754ce 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {
 
 // -----
 
+func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref<?xsi32, strided<[?]>, 1>) {
+  // expected-error @+1 {{expected collapsed type to be 'memref<?xsi32, strided<[1]>, 1>' but found 'memref<?xsi32, strided<[?]>, 1>'}}
+  %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref<?xsi32, strided<[?]>, 1>
+  return %collapse_shape : memref<?xsi32, strided<[?]>, 1>
+}
+
+// -----
+
 func.func @expand_shape_illegal_static_memref
   (%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
   // expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 13fdf3cf13510..123ac1cf4de94 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -417,7 +417,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
          %arg4: index,
          %arg5: index,
          %arg6: index,
-         %arg7: memref<4x?x4xf32>) {
+         %arg7: memref<4x?x4xf32>,
+         %arg8: memref<1x2x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>,
+         %arg9: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>,
+         %arg10: memref<1x1x?x1xsi32, strided<[36960, 4620, 2, 330]>, 1>) {
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<?x?x?xf32> into memref<?x?xf32>
   %0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -466,6 +469,22 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
   %4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
         : memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
+
+//       CHECK:   collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
+  %5 = memref.collapse_shape %arg8 [[0, 1, 2, 3]] :
+     memref<1x2x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into
+     memref<?xsi32, strided<[?]>, 1>
+
+//       CHECK:   collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
+  %6 = memref.collapse_shape %arg9 [[0, 1, 2, 3]] :
+     memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into
+     memref<?xsi32, strided<[1]>, 1>
+
+//       CHECK:   collapse_shape {{.*}} {{\[}}[0, 1, 2, 3]]
+  %7 = memref.collapse_shape %arg10 [[0, 1, 2, 3]] :
+     memref<1x1x?x1xsi32, strided<[36960, 4620, 2, 330]>, 1> into
+     memref<?xsi32, strided<[?]>, 1>
+
   return
 }
 



More information about the Mlir-commits mailing list