[Mlir-commits] [mlir] 76c0798 - [mlir][memref]: Allow collapse dummy strided unit dim (#103719)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 21 06:35:31 PDT 2024


Author: Aviad Cohen
Date: 2024-08-21T16:35:26+03:00
New Revision: 76c07984257b49dcc4786fa9fb3918a2c1342e23

URL: https://github.com/llvm/llvm-project/commit/76c07984257b49dcc4786fa9fb3918a2c1342e23
DIFF: https://github.com/llvm/llvm-project/commit/76c07984257b49dcc4786fa9fb3918a2c1342e23.diff

LOG: [mlir][memref]: Allow collapse dummy strided unit dim (#103719)

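Dimensions of size 1 should be skipped when checking that the strides of a collapsed group are contiguous. A unit dimension is only ever indexed at 0, so its stride never contributes to an address; it is therefore meaningless and may have an arbitrary value.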

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 0ff25de7295f6e..150049e5c5effe 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2448,6 +2448,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
+      // Dimensions of size 1 should be skipped, because their strides are
+      // meaningless and could have any arbitrary value.
+      if (srcShape[idx - 1] == 1)
+        continue;
+
       if (!stride.saturated && !srcStride.saturated && stride != srcStride)
         return failure();
     }

diff  --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index b60894377f22fc..f616f6795bf9dc 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -99,7 +99,9 @@ func.func @expand_collapse_shape_static(
     %arg4: memref<1x5xf32, strided<[5, 1], offset: ?>>,
     %arg5: memref<f32>,
     %arg6: memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>>,
-    %arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>) {
+    %arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>,
+    %arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
+    %arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
   // Reshapes that collapse and expand back a contiguous buffer.
 //       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<12x5xf32>
@@ -163,6 +165,19 @@ func.func @expand_collapse_shape_static(
       memref<1x2049xi64, strided<[?, ?], offset: ?>> into
       memref<2049xi64, strided<[?], offset: ?>>
 
+    // %arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
+    // %arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
+  %r8 = memref.collapse_shape %arg8 [[0, 1, 2]] :
+      memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>> into
+      memref<1024xi8, strided<[1], offset: 0>>
+
+//       CHECK:   memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
+  %r9 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] :
+      memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>> into
+      memref<24x1024xi8, strided<[40960, 1], offset: 0>>
+
   // Reshapes that expand and collapse back a contiguous buffer with some 1's.
 //       CHECK:   memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
 //  CHECK-SAME:     memref<3x4x5xf32> into memref<1x3x4x1x5xf32>


        


More information about the Mlir-commits mailing list