[Mlir-commits] [mlir] 76c0798 - [mlir][memref]: Allow collapse dummy strided unit dim (#103719)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Aug 21 06:35:31 PDT 2024
Author: Aviad Cohen
Date: 2024-08-21T16:35:26+03:00
New Revision: 76c07984257b49dcc4786fa9fb3918a2c1342e23
URL: https://github.com/llvm/llvm-project/commit/76c07984257b49dcc4786fa9fb3918a2c1342e23
DIFF: https://github.com/llvm/llvm-project/commit/76c07984257b49dcc4786fa9fb3918a2c1342e23.diff
LOG: [mlir][memref]: Allow collapse dummy strided unit dim (#103719)
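When checking whether the source strides of a `memref.collapse_shape` reassociation group are compatible, dimensions of size 1 should be skipped, because their strides are meaningless and may have arbitrary values.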
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/ops.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0ff25de7295f6e..150049e5c5effe 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2448,6 +2448,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (strict && (stride.saturated || srcStride.saturated))
return failure();
+ // Dimensions of size 1 should be skipped, because their strides are
+ // meaningless and could have any arbitrary value.
+ if (srcShape[idx - 1] == 1)
+ continue;
+
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
return failure();
}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index b60894377f22fc..f616f6795bf9dc 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -99,7 +99,9 @@ func.func @expand_collapse_shape_static(
%arg4: memref<1x5xf32, strided<[5, 1], offset: ?>>,
%arg5: memref<f32>,
%arg6: memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>>,
- %arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>) {
+ %arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>,
+ %arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
+ %arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
// Reshapes that collapse and expand back a contiguous buffer.
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
@@ -163,6 +165,19 @@ func.func @expand_collapse_shape_static(
memref<1x2049xi64, strided<[?, ?], offset: ?>> into
memref<2049xi64, strided<[?], offset: ?>>
+ // %arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
+ // %arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
+ %r8 = memref.collapse_shape %arg8 [[0, 1, 2]] :
+ memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>> into
+ memref<1024xi8, strided<[1], offset: 0>>
+
+// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
+ %r9 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] :
+ memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>> into
+ memref<24x1024xi8, strided<[40960, 1], offset: 0>>
+
// Reshapes that expand and collapse back a contiguous buffer with some 1's.
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
More information about the Mlir-commits mailing list