[Mlir-commits] [mlir] [mlir][linalg] Expose transform.fuse_into_containing_op helpers: NFC (PR #72473)

Quinn Dawkins llvmlistbot at llvm.org
Wed Nov 15 20:57:23 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/72473

This exposes the fusion helpers called by
`transform.fuse_into_containing_op` by declaring them in
`LinalgTransformOps.h`, so that they can be reused in other contexts.

From 77734914e794c840d039cc46d0f0e4862661d327 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 15 Nov 2023 15:30:44 -0500
Subject: [PATCH] [mlir][linalg] Expose fuse_into_containing_op helpers: NFC

This allows use of the fusion implementations in
`transform.fuse_into_containing_op` in other contexts.
---
 .../Linalg/TransformOps/LinalgTransformOps.h  | 27 ++++++++++++++++
 .../TransformOps/LinalgTransformOps.cpp       | 31 +++++++------------
 2 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 12923663b3fb6ce6..f9f96ae0b7b6494b 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -59,6 +59,33 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
                    std::optional<ArrayAttr> mapping,
                    linalg::ForallTilingResult &tilingResult);
 
+/// Find the first "extract" user of `producerOp` and tile it right before its
+/// use. The tiled op is fused under the `containingOp`.
+/// Return this fused op on success or nullptr if anything fails.
+/// If tiled op has uses that are dominated by `containingOp`, return
+/// a new `containingOp` with results of the fused op appended to
+/// results of the `containingOp` or nullptr if there are no dominated uses.
+std::tuple<SmallVector<Operation *>, Operation *>
+tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
+                           Operation *producerOp, Operation *containingOp);
+
+/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
+/// it is exactly the `containingOp`, otherwise bail.
+/// Then, find the first "extract" user of the tied block argument and tile it
+/// right before its "extract" use. The tiled op is fused under the
+/// `containingOp`.
+/// Return this fused op on success or nullptr if anything fails.
+SmallVector<Operation *>
+tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
+    Operation *containingOp);
+
+/// Find the first use of `producerOp` inside `containingOp` and fuse into
+/// the containing op by cloning the producer. Return nullptr if no such
+/// fusion opportunity exists.
+Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
+                                Operation *producerOp, Operation *containingOp);
+
 } // namespace transform
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162eaa..fe98dfbf287a8adc 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -628,15 +628,11 @@ static Operation *replaceForAllWithNewSignature(
   return newforallOp;
 }
 
-/// Find the first "extract" user of `producerOp` and tile it right before its
-/// use. The tiled op is fused under the `containingOp`.
-/// Return this fused op on success or nullptr if anything fails.
-/// If tiled op has uses that are dominated by `containingOp`, return
-/// a new `containingOp` with results of the fused op appended to
-/// results of the `containingOp` or nullptr if there are no dominated uses.
-static std::tuple<SmallVector<Operation *>, Operation *>
-tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
-                           Operation *producerOp, Operation *containingOp) {
+std::tuple<SmallVector<Operation *>, Operation *>
+mlir::transform::tileAndFuseFirstExtractUse(RewriterBase &rewriter,
+                                            Diagnostic &diag,
+                                            Operation *producerOp,
+                                            Operation *containingOp) {
   LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer) {
@@ -710,14 +706,8 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
   return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
 }
 
-/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
-/// it is exactly the `containingOp`, otherwise bail.
-/// Then, find the first "extract" user of the tied block argument and tile it
-/// right before its "extract" use. The tiled op is fused under the
-/// `containingOp`.
-/// Return this fused op on success or nullptr if anything fails.
-static SmallVector<Operation *>
-tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+SmallVector<Operation *>
+mlir::transform::tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
     RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
     Operation *containingOp) {
   LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n");
@@ -819,9 +809,10 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   return tileAndFuseResult->tiledOps;
 }
 
-static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
-                                       Operation *producerOp,
-                                       Operation *containingOp) {
+Operation *mlir::transform::cloneAndFuseFirstUse(RewriterBase &rewriter,
+                                                 Diagnostic &diag,
+                                                 Operation *producerOp,
+                                                 Operation *containingOp) {
   LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n");
 
   // Gather all uses inside the containing op.



More information about the Mlir-commits mailing list