[Mlir-commits] [mlir] 1407f5b - [mlir] Canonicalize extract_slice(unpack) (#133777)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 1 11:52:03 PDT 2025
Author: Max191
Date: 2025-04-01T14:51:58-04:00
New Revision: 1407f5bee9aa8e2a8a4fcab63ab0a3030a8b0dcf
URL: https://github.com/llvm/llvm-project/commit/1407f5bee9aa8e2a8a4fcab63ab0a3030a8b0dcf
DIFF: https://github.com/llvm/llvm-project/commit/1407f5bee9aa8e2a8a4fcab63ab0a3030a8b0dcf.diff
LOG: [mlir] Canonicalize extract_slice(unpack) (#133777)
Canonicalizes a chain of `linalg.unpack -> tensor.extract_slice` into a
`linalg.unpack` with reduced dest sizes. This happens only when the
unpack op's sole user is a non-rank-reducing slice with all-zero offsets
and unit strides.
---------
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
Signed-off-by: Max Dawkins <maxdawkins19 at gmail.com>
Co-authored-by: Max Dawkins <maxdawkins19 at gmail.com>
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index ff89ead59981c..d6b093c5fb86b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -29,6 +29,7 @@
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
@@ -5243,6 +5244,29 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
return success();
}
+ /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
+ if (unPackOp->hasOneUse()) {
+ auto extractSliceUser =
+ dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
+ if (extractSliceUser &&
+ areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
+ areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
+ extractSliceUser.getSourceType().getRank() ==
+ extractSliceUser.getResultType().getRank()) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(unPackOp);
+ auto newDest = rewriter.create<tensor::ExtractSliceOp>(
+ unPackOp->getLoc(), unPackOp.getDest(),
+ extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
+ extractSliceUser.getMixedStrides());
+ rewriter.modifyOpInPlace(unPackOp, [&]() {
+ unPackOp.setDpsInitOperand(0, newDest);
+ unPackOp.getResult().setType(newDest.getType());
+ });
+ rewriter.replaceOp(extractSliceUser, unPackOp);
+ return success();
+ }
+ }
// Insert tensor.cast ops if static shape inference is available..
SmallVector<int64_t> srcShape, destShape;
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index f99491c25d832..86cb8f58abe02 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1772,3 +1772,78 @@ func.func @fold_cast_unpack_dynamic_tile_size(
into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
return %unpack : tensor<7x?xi32>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// linalg.unpack + tensor.extract_slice
+//===----------------------------------------------------------------------===//
+// Folds: zero offsets, unit strides, rank-preserving -> slice moves onto %dest.
+func.func @fold_extract_slice_into_unpack(
+ %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1, 2]
+ inner_dims_pos = [1, 2]
+ inner_tiles = [16, 16]
+ into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+ return %extracted_slice : tensor<28x28x?xf32>
+}
+
+// CHECK-LABEL: func @fold_extract_slice_into_unpack
+// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
+// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
+// CHECK-SAME: %[[SIZE:.+]]: index
+// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
+// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST_SLICE]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+// No fold: the slice is rank-reducing (tensor<28x32xf32> -> tensor<28xf32>).
+func.func @no_fold_extract_slice_into_unpack_rank_reducing(
+ %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1]
+ inner_dims_pos = [1]
+ inner_tiles = [16]
+ into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 0] [1, 28] [1, 1] : tensor<28x32xf32> to tensor<28xf32>
+ return %extracted_slice : tensor<28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
+// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
+// CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK: return %[[SLICE]]
+
+// -----
+// No fold: the slice has a non-zero offset (offsets [0, 1]).
+func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
+ %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
+) -> tensor<28x28xf32> {
+ %unpack = linalg.unpack %src
+ outer_dims_perm = [0, 1]
+ inner_dims_pos = [1]
+ inner_tiles = [16]
+ into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
+ %extracted_slice = tensor.extract_slice %unpack
+ [0, 1] [28, 28] [1, 1] : tensor<28x32xf32> to tensor<28x28xf32>
+ return %extracted_slice : tensor<28x28xf32>
+}
+
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
+// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
+// CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
+// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
+// CHECK-SAME: into %[[DEST]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
+// CHECK: return %[[SLICE]]
More information about the Mlir-commits
mailing list