[llvm-branch-commits] [mlir] 3414bbf - Revert "[mlir][memref]: Collapse strided unit dim even if strides are dynamic…"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Nov 4 13:28:03 PST 2025
Author: Han-Chung Wang
Date: 2025-11-04T13:28:00-08:00
New Revision: 3414bbf12a661bc518953d91e859a5b67b6dc432
URL: https://github.com/llvm/llvm-project/commit/3414bbf12a661bc518953d91e859a5b67b6dc432
DIFF: https://github.com/llvm/llvm-project/commit/3414bbf12a661bc518953d91e859a5b67b6dc432.diff
LOG: Revert "[mlir][memref]: Collapse strided unit dim even if strides are dynamic…"
This reverts commit f74e90961f51c9437461007c89b037be41e4e887.
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e271ac58db327..1c21a2f270da6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2568,11 +2568,6 @@ computeCollapsedLayoutMap(MemRefType srcType,
auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
for (int64_t idx : llvm::reverse(trailingReassocs)) {
- // Dimensions of size 1 should be skipped, because their strides are
- // meaningless and could have any arbitrary value.
- if (srcShape[idx - 1] == 1)
- continue;
-
stride = stride * SaturatedInteger::wrap(srcShape[idx]);
// Both source and result stride must have the same static value. In that
@@ -2587,6 +2582,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (strict && (stride.saturated || srcStride.saturated))
return failure();
+ // Dimensions of size 1 should be skipped, because their strides are
+ // meaningless and could have any arbitrary value.
+ if (srcShape[idx - 1] == 1)
+ continue;
+
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
return failure();
}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index b1db99bb3ad08..a90c9505a8405 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -440,8 +440,7 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
%arg4: index,
%arg5: index,
%arg6: index,
- %arg7: memref<4x?x4xf32>,
- %arg8: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>>) {
+ %arg7: memref<4x?x4xf32>) {
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<?x?x?xf32> into memref<?x?xf32>
%0 = memref.collapse_shape %arg0 [[0, 1], [2]] :
@@ -490,10 +489,6 @@ func.func @expand_collapse_shape_dynamic(%arg0: memref<?x?x?xf32>,
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]]
%4 = memref.expand_shape %arg7 [[0, 1], [2], [3, 4]] output_shape [2, 2, %arg4, 2, 2]
: memref<4x?x4xf32> into memref<2x2x?x2x2xf32>
-
-// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2], [3]]
-// CHECK-SAME: memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
- %5 = memref.collapse_shape %arg8 [[0, 1], [2], [3]] : memref<1x1x18x?xsi8, strided<[?, ?, ?, 1], offset: ?>> into memref<1x18x?xsi8, strided<[?, ?, 1], offset: ?>>
return
}
More information about the llvm-branch-commits mailing list