[Mlir-commits] [mlir] Reapply [mlir][memref]: Allow collapse of strided unit dim even if strides are dynamic (PR #171039)
Maya Amrami
llvmlistbot at llvm.org
Thu Dec 11 02:50:39 PST 2025
https://github.com/amrami updated https://github.com/llvm/llvm-project/pull/171039
>From 01195e97edae56d6d223312877e9278c070055b6 Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Sun, 7 Dec 2025 15:50:54 +0200
Subject: [PATCH] Reapply [mlir][memref]: Allow collapse of strided unit dim
even if strides are dynamic
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++++-----
mlir/test/Dialect/MemRef/ops.mlir | 15 ++++++++++++++-
2 files changed, 19 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..4639016fb589a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2563,6 +2563,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
for (int64_t idx : llvm::reverse(trailingReassocs)) {
stride = stride * SaturatedInteger::wrap(srcShape[idx]);
+ // Dimensions of size 1 should be skipped, because their strides are
+ // meaningless and could have any arbitrary value.
+ if (srcShape[idx - 1] == 1)
+ continue;
+
// Both source and result stride must have the same static value. In that
// case, we can be sure, that the dimensions are collapsible (because they
// are contiguous).
@@ -2575,11 +2580,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (strict && (stride.saturated || srcStride.saturated))
return failure();
- // Dimensions of size 1 should be skipped, because their strides are
- // meaningless and could have any arbitrary value.
- if (srcShape[idx - 1] == 1)
- continue;
-
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
return failure();
}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index a90c9505a8405..cddc79f693b11 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -440,7 +440,10 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
%arg4: index,
%arg5: index,
%arg6: index,
- %arg7: memref<4x?x4xf32>) {
+ %arg7: memref<4x?x4xf32>,
+ %arg8: memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>>,
+ %arg9: memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>>) {
+
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -489,6 +492,16 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
+// CHECK-SAME: memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xf32, strided<[?, ?, 1], offset: ?>>
+ %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xf32, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xf32, strided<[?, ?, 1], offset: ?>>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
+// CHECK-SAME: memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> into memref<3x288xf32, strided<[288, 1], offset: 864>>
+ %6 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] :
+ memref<3x3x1x96xf32, strided<[288, 96, 96, 1], offset: 864>> into
+ memref<3x288xf32, strided<[288, 1], offset: 864>>
return
}
More information about the Mlir-commits mailing list