[Mlir-commits] [mlir] [mlir] Canonicalize extract_slice(unpack) (PR #133777)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 31 11:51:19 PDT 2025


https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/133777

Canonicalizes a chain of `linalg.unpack -> tensor.extract_slice` into a single `linalg.unpack` with reduced destination sizes: the new unpack writes into a `tensor.extract_slice` of the original destination. This only happens when the unpack op's only user is a non-rank-reducing slice with all-zero offsets and unit strides.

>From 1a46dc2b3f1aa898806d218e30c1f36c455dd093 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Mon, 31 Mar 2025 13:21:58 -0500
Subject: [PATCH] [mlir] Canonicalize extract_slice(unpack)

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 20 ++++++
 mlir/test/Dialect/Linalg/canonicalize.mlir | 75 ++++++++++++++++++++++
 2 files changed, 95 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ff89ead59981c..fc5d8472a9a7b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5243,6 +5243,26 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
   }
+  /// extract_slice(unpack(x)) -> unpack(x)
+  if (unPackOp->hasOneUse()) {
+    auto extractSliceUser =
+        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
+    if (extractSliceUser &&
+        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
+        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
+        extractSliceUser.getSourceType().getRank() ==
+            extractSliceUser.getResultType().getRank()) {
+      auto newDest = rewriter.create<tensor::ExtractSliceOp>(
+          unPackOp->getLoc(), unPackOp.getDest(),
+          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
+          extractSliceUser.getMixedStrides());
+      rewriter.replaceOpWithNewOp<UnPackOp>(
+          extractSliceUser, unPackOp.getSource(), newDest,
+          unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
+          unPackOp.getOuterDimsPerm());
+      return success();
+    }
+  }
 
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f99491c25d832..86cb8f58abe02 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1772,3 +1772,78 @@ func.func @fold_cast_unpack_dynamic_tile_size(
       into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
     return %unpack : tensor<7x?xi32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.unpack + tensor.extract_slice
+//===----------------------------------------------------------------------===//
+
+func.func @fold_extract_slice_into_unpack(%packed : tensor<28x2x?x16x16xf32>,
+    %init : tensor<28x32x?xf32>, %sz : index) -> tensor<28x28x?xf32> {
+  // Zero offsets, unit strides, same rank: the slice folds into the unpack.
+  %unpacked = linalg.unpack %packed
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %init : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %slice = tensor.extract_slice %unpacked[0, 0, 0] [28, 28, %sz] [1, 1, 1]
+      : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %slice : tensor<28x28x?xf32>
+}
+
+// CHECK-LABEL: func @fold_extract_slice_into_unpack
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32x?xf32>
+//  CHECK-SAME:     %[[SIZE:.+]]: index
+//       CHECK:   %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+//  CHECK-SAME:     [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST_SLICE]]
+//       CHECK:   return %[[UNPACK]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_rank_reducing(
+    %packed : tensor<28x2x16xf32>, %init : tensor<28x32xf32>) -> tensor<28xf32> {
+  // Rank-reducing slices are excluded by the fold, so the unpack must stay.
+  %unpacked = linalg.unpack %packed
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %init : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %slice = tensor.extract_slice %unpacked[0, 0] [1, 28] [1, 1]
+      : tensor<28x32xf32> to tensor<28xf32>
+  return %slice : tensor<28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32xf32>
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST]]
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//       CHECK:   return %[[SLICE]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
+    %packed : tensor<28x2x16xf32>, %init : tensor<28x32xf32>) -> tensor<28x28xf32> {
+  // Non-zero offsets fall outside the fold's preconditions; no rewrite happens.
+  %unpacked = linalg.unpack %packed
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %init : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %slice = tensor.extract_slice %unpacked[0, 1] [28, 28] [1, 1]
+      : tensor<28x32xf32> to tensor<28x28xf32>
+  return %slice : tensor<28x28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32xf32>
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST]]
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//       CHECK:   return %[[SLICE]]



More information about the Mlir-commits mailing list