[Mlir-commits] [mlir] [mlir] Canonicalize extract_slice(unpack) (PR #133777)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 31 11:51:19 PDT 2025
https://github.com/Max191 created https://github.com/llvm/llvm-project/pull/133777
Canonicalizes a chain of `linalg.unpack -> tensor.extract_slice` into a single `linalg.unpack` with reduced destination sizes. The rewrite only applies when the unpack op's only user is a non-rank-reducing slice with zero offsets and unit strides.
>From 1a46dc2b3f1aa898806d218e30c1f36c455dd093 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Mon, 31 Mar 2025 13:21:58 -0500
Subject: [PATCH] [mlir] Canonicalize extract_slice(unpack)
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 54 ++++++++
mlir/test/Dialect/Linalg/canonicalize.mlir | 151 ++++++++++++++++++++++
2 files changed, 205 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ff89ead59981c..fc5d8472a9a7b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5243,6 +5243,60 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
+  // extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
+  //
+  // Only a zero-offset, unit-stride, non-rank-reducing slice that trims
+  // nothing but padding is folded: along every tiled dimension the sliced size
+  // must still need all of the source's outer tiles, otherwise the new unpack
+  // would drop at least one whole tile. Tiled dimensions whose tile, outer, or
+  // sliced size is dynamic cannot be checked and are not folded.
+  if (unPackOp->hasOneUse()) {
+    auto sliceOp =
+        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
+    auto onlyTrimsPadding = [&]() -> bool {
+      RankedTensorType slicedType = sliceOp.getResultType();
+      if (sliceOp.getSourceType().getRank() != slicedType.getRank())
+        return false;
+      if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+          !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+        return false;
+      ArrayRef<int64_t> packedShape = unPackOp.getSourceType().getShape();
+      ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
+      for (auto [pos, tile] : llvm::zip_equal(unPackOp.getInnerDimsPos(),
+                                              unPackOp.getStaticInnerTiles())) {
+        // Packed outer dim `i` holds the tiles of unpacked dim
+        // `outerDimsPerm[i]`; find the one that belongs to `pos`.
+        int64_t outerIdx = pos;
+        if (!outerDimsPerm.empty())
+          outerIdx = std::distance(outerDimsPerm.begin(),
+                                   llvm::find(outerDimsPerm, pos));
+        int64_t outerSize = packedShape[outerIdx];
+        int64_t slicedSize = slicedType.getDimSize(pos);
+        if (ShapedType::isDynamic(tile) || ShapedType::isDynamic(outerSize) ||
+            ShapedType::isDynamic(slicedSize))
+          return false;
+        // Trimming `tile` or more elements would discard a whole tile.
+        if (outerSize * tile - slicedSize >= tile)
+          return false;
+      }
+      return true;
+    };
+    if (sliceOp && onlyTrimsPadding()) {
+      // Build next to the slice: its dynamic sizes may be defined after the
+      // unpack, so inserting before the unpack could violate dominance.
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(sliceOp);
+      auto newDest = rewriter.create<tensor::ExtractSliceOp>(
+          unPackOp->getLoc(), unPackOp.getDest(), sliceOp.getMixedOffsets(),
+          sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+      auto newUnPack = rewriter.replaceOpWithNewOp<UnPackOp>(
+          sliceOp, unPackOp.getSource(), newDest, unPackOp.getInnerDimsPos(),
+          unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
+      // Carry over discardable attributes (e.g. `{test_attr}`) of the unpack.
+      newUnPack->setDiscardableAttrs(unPackOp->getDiscardableAttrDictionary());
+      return success();
+    }
+  }
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f99491c25d832..86cb8f58abe02 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1772,3 +1772,154 @@ func.func @fold_cast_unpack_dynamic_tile_size(
into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
return %unpack : tensor<7x?xi32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.unpack + tensor.extract_slice
+//===----------------------------------------------------------------------===//
+
+// The slice only trims padding (fewer than 16 elements per tiled dim), so it
+// is folded into the unpack destination. The non-identity outer_dims_perm
+// checks that each tiled dim is compared against its own packed outer dim.
+func.func @fold_extract_slice_into_unpack(
+    %src : tensor<28x3x2x16x16xf32>, %dest : tensor<28x32x48xf32>
+) -> tensor<28x20x40xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 2, 1]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x3x2x16x16xf32> -> tensor<28x32x48xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 20, 40] [1, 1, 1] : tensor<28x32x48xf32> to tensor<28x20x40xf32>
+  return %extracted_slice : tensor<28x20x40xf32>
+}
+
+// CHECK-LABEL: func @fold_extract_slice_into_unpack(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9_]+]]: tensor<28x3x2x16x16xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9_]+]]: tensor<28x32x48xf32>
+// CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+// CHECK-SAME:      [0, 0, 0] [28, 20, 40] [1, 1, 1]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      into %[[DEST_SLICE]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
+// The untiled size is defined after the unpack, so the new destination slice
+// must be built next to the old slice to keep dominance intact.
+func.func @fold_extract_slice_into_unpack_dynamic_untiled_size(
+    %src : tensor<?x2x16xf32>, %dest : tensor<?x32xf32>
+) -> tensor<?x28xf32> {
+  %c0 = arith.constant 0 : index
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<?x2x16xf32> -> tensor<?x32xf32>
+  %size = tensor.dim %dest, %c0 : tensor<?x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0] [%size, 28] [1, 1] : tensor<?x32xf32> to tensor<?x28xf32>
+  return %extracted_slice : tensor<?x28xf32>
+}
+
+// CHECK-LABEL: func @fold_extract_slice_into_unpack_dynamic_untiled_size(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9_]+]]: tensor<?x2x16xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9_]+]]: tensor<?x32xf32>
+// CHECK:         %[[SIZE:.+]] = tensor.dim %[[DEST]]
+// CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+// CHECK-SAME:      [0, 0] [%[[SIZE]], 28] [1, 1]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      into %[[DEST_SLICE]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_rank_reducing(
+    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0] [1, 28] [1, 1] : tensor<28x32xf32> to tensor<28xf32>
+  return %extracted_slice : tensor<28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9_]+]]: tensor<28x2x16xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9_]+]]: tensor<28x32xf32>
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      into %[[DEST]]
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK:         return %[[SLICE]]
+
+// -----
+
+func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
+    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28x28xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1]
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 1] [28, 28] [1, 1] : tensor<28x32xf32> to tensor<28x28xf32>
+  return %extracted_slice : tensor<28x28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9_]+]]: tensor<28x2x16xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9_]+]]: tensor<28x32xf32>
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      into %[[DEST]]
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK:         return %[[SLICE]]
+
+// -----
+
+// Slicing 32 -> 16 with tile 16 would drop a whole tile, not just padding.
+func.func @no_fold_extract_slice_into_unpack_artificial_padding(
+    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28x16xf32> {
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0] [28, 16] [1, 1] : tensor<28x32xf32> to tensor<28x16xf32>
+  return %extracted_slice : tensor<28x16xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_artificial_padding(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9_]+]]: tensor<28x2x16xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9_]+]]: tensor<28x32xf32>
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      into %[[DEST]]
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK:         return %[[SLICE]]
+
+// -----
+
+// A dynamic size on a tiled dim cannot be proven to trim only padding.
+func.func @no_fold_extract_slice_into_unpack_dynamic_tiled_size(
+    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>, %size : index
+) -> tensor<28x?xf32> {
+  %unpack = linalg.unpack %src
+      inner_dims_pos = [1]
+      inner_tiles = [16]
+      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0] [28, %size] [1, 1] : tensor<28x32xf32> to tensor<28x?xf32>
+  return %extracted_slice : tensor<28x?xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic_tiled_size(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9_]+]]: tensor<28x2x16xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9_]+]]: tensor<28x32xf32>
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME:      into %[[DEST]]
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK:         return %[[SLICE]]
More information about the Mlir-commits
mailing list