[Mlir-commits] [mlir] [mlir][SCF] Make producer tile + fuse fusion function customizable (PR #107883)

Quinn Dawkins llvmlistbot at llvm.org
Mon Sep 9 09:08:34 PDT 2024


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/107883

None

>From bdb8ad98c6f0b5461dbba8bce9846356e78e376b Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Mon, 9 Sep 2024 12:07:21 -0400
Subject: [PATCH] [mlir][SCF] Make producer tile + fuse fusion function
 customizable

---
 .../SCF/Transforms/TileUsingInterface.h       | 35 ++++++++++++-------
 .../SCF/Transforms/TileUsingInterface.cpp     |  2 +-
 2 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f21af6d6a29ac..9a85c84dac8068 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
 #define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
 
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
@@ -114,6 +115,20 @@ FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
                                         TilingInterface op,
                                         const SCFTilingOptions &options);
 
+/// Result of `tileAndFuseProducerOfSlice` (declared below), which fuses the
+/// producer of the source of `candidateSliceOp` by computing the required
+/// slice of the producer in-place. It replaces uses of `candidateSliceOp` with
+/// the tiled and fused producer value but does not delete the slice operation.
+struct SCFFuseProducerOfSliceResult {
+  OpResult origProducer;       // Original untiled producer.
+  Value tiledAndFusedProducer; // Tiled and fused producer value.
+  SmallVector<Operation *> tiledOps;
+};
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                           tensor::ExtractSliceOp candidateSliceOp,
+                           MutableArrayRef<LoopLikeOpInterface> loops);
+
 /// Options used to control tile + fuse.
 struct SCFTileAndFuseOptions {
   /// The tiling options used to control the tiling of the consumer.
@@ -146,21 +161,15 @@ struct SCFTileAndFuseOptions {
     fusionControlFn = controlFn;
     return *this;
   }
-};
 
-/// Fuse the producer of the source of `candidateSliceOp` by computing the
-/// required slice of the producer in-place.  Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
-/// value but does not delete the slice operation.
-struct SCFFuseProducerOfSliceResult {
-  OpResult origProducer;       // Original untiled producer.
-  Value tiledAndFusedProducer; // Tile and fused producer value.
-  SmallVector<Operation *> tiledOps;
+  /// Customizable function implementing the fusion logic for the producer
+  /// of the given slice.
+  using FusionFnTy =
+      std::function<std::optional<scf::SCFFuseProducerOfSliceResult>(
+          RewriterBase &, tensor::ExtractSliceOp,
+          MutableArrayRef<LoopLikeOpInterface>)>;
+  FusionFnTy fusionFn = tileAndFuseProducerOfSlice;
 };
-std::optional<SCFFuseProducerOfSliceResult>
-tileAndFuseProducerOfSlice(RewriterBase &rewriter,
-                           tensor::ExtractSliceOp candidateSliceOp,
-                           MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Reconstruct the fused producer from within the tiled-and-fused code. Based
 /// on the slice of the producer computed in place it is possible that within
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..1903c79fc92e6e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1389,7 +1389,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
-        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
+        options.fusionFn(rewriter, candidateSliceOp, loops);
     if (!fusedResult)
       continue;
 



More information about the Mlir-commits mailing list