[Mlir-commits] [mlir] [NFC] Add allow Insert/extract slice option to pack/unpack op (PR #117340)

Zhuoran Yin llvmlistbot at llvm.org
Fri Nov 22 08:02:57 PST 2024


https://github.com/jerryyin created https://github.com/llvm/llvm-project/pull/117340

This PR adds the two options below. Both default to true, so the existing lowering behavior of the pack and unpack ops is unchanged.
 - allowInsertSliceLowering to packOp (with default = true)
 - allowExtractSliceLowering to unPackOp (with default = true)

The motivation for this PR is finer-grained control over the lowering of pack and unpack ops. This is particularly useful when we want to guarantee that no additional insert_slice or extract_slice ops are introduced that would interfere with tiling. With the original lowering pipeline, packOp and unPackOp may be lowered to insert_slice and extract_slice when the high dimensions are unit dimensions and no transpose is involved. In those cases, the resulting insert_slice and extract_slice ops block consumer fusion in the TileAndFuse pipeline. With this PR, that lowering path can be disabled so that consumer fusion goes through as expected.

>From 012f6d46ff6c187470d6ca102be513e7a5a78a21 Mon Sep 17 00:00:00 2001
From: jerryyin <zhuoryin at amd.com>
Date: Fri, 22 Nov 2024 15:52:12 +0000
Subject: [PATCH] [NFC] Add allowInsertSliceLowering to packOp and
 allowExtractSliceLowering to UnPackOp

---
 .../Linalg/TransformOps/LinalgTransformOps.td        |  6 ++++--
 .../mlir/Dialect/Linalg/Transforms/Transforms.h      |  8 +++++---
 .../Linalg/TransformOps/LinalgTransformOps.cpp       |  8 ++++++--
 mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp    | 12 +++++++-----
 4 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index e3084530bd11b5..ea96da77b6c331 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -548,7 +548,8 @@ def LowerPackOp : Op<Transform_Dialect, "structured.lower_pack", [
     Return handles to the newly produced pad, expand_shape and transpose ops.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.pack">:$target,
+                       DefaultValuedAttr<BoolAttr, "true">:$allowInsertSliceLowering);
   let results = (outs Transform_ConcreteOpType<"tensor.pad">:$pad_op,
                       Transform_ConcreteOpType<"tensor.expand_shape">:$expand_shape_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op);
@@ -588,7 +589,8 @@ def LowerUnPackOp : Op<Transform_Dialect, "structured.lower_unpack", [
     Return handles to the newly produced empty, transpose, collapse_shape and extract_slice ops.
   }];
 
-  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target);
+  let arguments = (ins Transform_ConcreteOpType<"tensor.unpack">:$target,
+                       DefaultValuedAttr<BoolAttr, "true">:$allowExtractSliceLowering);
   let results = (outs Transform_ConcreteOpType<"tensor.empty">:$empty_op,
                       Transform_ConcreteOpType<"linalg.transpose">:$transpose_op,
                       Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op,
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 51967f83fee377..fd27e7929764d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1121,7 +1121,8 @@ struct LowerPackResult {
 
 /// Rewrite pack as pad + reshape + transpose.
 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter,
-                                     tensor::PackOp packOp);
+                                     tensor::PackOp packOp,
+                                     bool allowInsertSliceLowering = true);
 
 struct LowerUnPackOpResult {
   tensor::EmptyOp emptyOp;
@@ -1131,8 +1132,9 @@ struct LowerUnPackOpResult {
 };
 
 /// Rewrite pack as empty + transpose + reshape + extract_slice.
-FailureOr<LowerUnPackOpResult> lowerUnPack(RewriterBase &rewriter,
-                                           tensor::UnPackOp unPackOp);
+FailureOr<LowerUnPackOpResult>
+lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+            bool allowExtractSliceLowering = true);
 
 /// Struct to hold the result of a `pack` call.
 struct PackResult {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ada80deacfdbfe..5117a5c58c381d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1171,7 +1171,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerPackResult> res = lowerPack(rewriter, target);
+  bool allowInsertSliceLowering = getAllowInsertSliceLowering();
+  FailureOr<LowerPackResult> res =
+      lowerPack(rewriter, target, allowInsertSliceLowering);
   if (failed(res)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "cannot lower to pad + expand + transpose";
@@ -1191,7 +1193,9 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
     transform::ApplyToEachResultList &transformResults,
     transform::TransformState &state) {
   rewriter.setInsertionPoint(target);
-  FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
+  bool allowExtractSliceLowering = getAllowExtractSliceLowering();
+  FailureOr<LowerUnPackOpResult> res =
+      lowerUnPack(rewriter, target, allowExtractSliceLowering);
   if (failed(res)) {
     DiagnosedSilenceableFailure diag =
         emitSilenceableError()
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index d92543d7264625..0717dad4c2852f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -217,7 +217,8 @@ struct PackedOperandsDimList {
 } // namespace
 
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
-                                             tensor::PackOp packOp) {
+                                             tensor::PackOp packOp,
+                                             bool allowInsertSliceLowering) {
   // 1. Filter out NYI cases.
   auto packedTensorType =
       cast<RankedTensorType>(packOp->getResultTypes().front());
@@ -295,7 +296,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
       llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
       DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
 
-  if (packOp.isLikePad()) {
+  if (allowInsertSliceLowering && packOp.isLikePad()) {
     // Pack ops which operate as simple pads may not produce legal
     // tensor.insert_slice operations when the packed type does not rank reduce
     // to the padded type.
@@ -351,8 +352,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   return LowerPackResult{padOp, reshapeOp, transposeOp};
 }
 
-FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
-                                                   tensor::UnPackOp unPackOp) {
+FailureOr<LowerUnPackOpResult>
+linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
+                    bool allowExtractSliceLowering) {
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
@@ -362,7 +364,7 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
   auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
-  if (unPackOp.isLikeUnPad()) {
+  if (allowExtractSliceLowering && unPackOp.isLikeUnPad()) {
     // This unpack is just a plain unpad.
     // Just extract the slice from the higher ranked tensor.
     ArrayRef<int64_t> destShape = destTensorType.getShape();



More information about the Mlir-commits mailing list