[Mlir-commits] [mlir] 1407f5b - [mlir] Canonicalize extract_slice(unpack) (#133777)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 1 11:52:03 PDT 2025


Author: Max191
Date: 2025-04-01T14:51:58-04:00
New Revision: 1407f5bee9aa8e2a8a4fcab63ab0a3030a8b0dcf

URL: https://github.com/llvm/llvm-project/commit/1407f5bee9aa8e2a8a4fcab63ab0a3030a8b0dcf
DIFF: https://github.com/llvm/llvm-project/commit/1407f5bee9aa8e2a8a4fcab63ab0a3030a8b0dcf.diff

LOG: [mlir] Canonicalize extract_slice(unpack) (#133777)

Canonicalizes a chain of `linalg.unpack -> tensor.extract_slice` into a
`linalg.unpack` with reduced dest sizes. This only happens when the
unpack op's sole user is a non-rank-reducing slice with zero offsets and
unit strides.

---------

Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
Signed-off-by: Max Dawkins <maxdawkins19 at gmail.com>
Co-authored-by: Max Dawkins <maxdawkins19 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ff89ead59981c..d6b093c5fb86b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -29,6 +29,7 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
@@ -5243,6 +5244,29 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                              [&]() { unPackOp.setDpsInitOperand(0, newDest); });
     return success();
   }
+  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
+  if (unPackOp->hasOneUse()) {
+    auto extractSliceUser =
+        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
+    if (extractSliceUser &&
+        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
+        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
+        extractSliceUser.getSourceType().getRank() ==
+            extractSliceUser.getResultType().getRank()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(unPackOp);
+      auto newDest = rewriter.create<tensor::ExtractSliceOp>(
+          unPackOp->getLoc(), unPackOp.getDest(),
+          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
+          extractSliceUser.getMixedStrides());
+      rewriter.modifyOpInPlace(unPackOp, [&]() {
+        unPackOp.setDpsInitOperand(0, newDest);
+        unPackOp.getResult().setType(newDest.getType());
+      });
+      rewriter.replaceOp(extractSliceUser, unPackOp);
+      return success();
+    }
+  }
 
   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;

diff  --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f99491c25d832..86cb8f58abe02 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1772,3 +1772,78 @@ func.func @fold_cast_unpack_dynamic_tile_size(
       into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
     return %unpack : tensor<7x?xi32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.unpack + tensor.extract_slice
+//===----------------------------------------------------------------------===//
+
+func.func @fold_extract_slice_into_unpack(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+
+// CHECK-LABEL: func @fold_extract_slice_into_unpack
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32x?xf32>
+//  CHECK-SAME:     %[[SIZE:.+]]: index
+//       CHECK:   %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+//  CHECK-SAME:     [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST_SLICE]]
+//       CHECK:   return %[[UNPACK]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_rank_reducing(
+    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0] [1, 28] [1, 1] : tensor<28x32xf32> to tensor<28xf32>
+  return %extracted_slice : tensor<28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32xf32>
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST]]
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//       CHECK:   return %[[SLICE]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
+    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28x28xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 1] [28, 28] [1, 1] : tensor<28x32xf32> to tensor<28x28xf32>
+  return %extracted_slice : tensor<28x28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
+//  CHECK-SAME:     %[[SRC:.+]]: tensor<28x2x16xf32>
+//  CHECK-SAME:     %[[DEST:.+]]: tensor<28x32xf32>
+//       CHECK:   %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+//  CHECK-SAME:       into %[[DEST]]
+//       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+//       CHECK:   return %[[SLICE]]


        


More information about the Mlir-commits mailing list