[Mlir-commits] [mlir] [mlir][SCF] Make producer tile + fuse fusion function customizable (PR #107883)
Quinn Dawkins
llvmlistbot at llvm.org
Mon Sep 9 09:08:34 PDT 2024
https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/107883
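Adds a customizable `fusionFn` to `SCFTileAndFuseOptions` (defaulting to `tileAndFuseProducerOfSlice`) and uses it in `tileConsumerAndFuseProducersUsingSCF` to fuse producers of candidate slices.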
>From bdb8ad98c6f0b5461dbba8bce9846356e78e376b Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Mon, 9 Sep 2024 12:07:21 -0400
Subject: [PATCH] [mlir][SCF] Make producer tile + fuse fusion function
customizable
---
.../SCF/Transforms/TileUsingInterface.h | 35 ++++++++++++-------
.../SCF/Transforms/TileUsingInterface.cpp | 2 +-
2 files changed, 23 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f21af6d6a29ac..9a85c84dac8068 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
#define MLIR_DIALECT_SCF_TRANSFORMS_TILEUSINGINTERFACE_H
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
@@ -114,6 +115,20 @@ FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
TilingInterface op,
const SCFTilingOptions &options);
+/// Result of `tileAndFuseProducerOfSlice`, which fuses the producer of the
+/// source of `candidateSliceOp` by computing the required producer slice
+/// in-place. The function replaces the uses of `candidateSliceOp` with the
+/// tiled and fused producer value but does not delete the slice operation.
+struct SCFFuseProducerOfSliceResult {
+ OpResult origProducer; // Original untiled producer.
+ Value tiledAndFusedProducer; // Tiled and fused producer value.
+ SmallVector<Operation *> tiledOps;
+};
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+ tensor::ExtractSliceOp candidateSliceOp,
+ MutableArrayRef<LoopLikeOpInterface> loops);
+
/// Options used to control tile + fuse.
struct SCFTileAndFuseOptions {
/// The tiling options used to control the tiling of the consumer.
@@ -146,21 +161,15 @@ struct SCFTileAndFuseOptions {
fusionControlFn = controlFn;
return *this;
}
-};
-/// Fuse the producer of the source of `candidateSliceOp` by computing the
-/// required slice of the producer in-place. Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
-/// value but does not delete the slice operation.
-struct SCFFuseProducerOfSliceResult {
- OpResult origProducer; // Original untiled producer.
- Value tiledAndFusedProducer; // Tile and fused producer value.
- SmallVector<Operation *> tiledOps;
+ /// Function that tiles and fuses the producer of a candidate slice; returning
+ /// std::nullopt skips that slice. Defaults to `tileAndFuseProducerOfSlice`.
+ using FusionFnTy =
+ std::function<std::optional<scf::SCFFuseProducerOfSliceResult>(
+ RewriterBase &, tensor::ExtractSliceOp,
+ MutableArrayRef<LoopLikeOpInterface>)>;
+ FusionFnTy fusionFn = tileAndFuseProducerOfSlice;
};
-std::optional<SCFFuseProducerOfSliceResult>
-tileAndFuseProducerOfSlice(RewriterBase &rewriter,
- tensor::ExtractSliceOp candidateSliceOp,
- MutableArrayRef<LoopLikeOpInterface> loops);
/// Reconstruct the fused producer from within the tiled-and-fused code. Based
/// on the slice of the producer computed in place it is possible that within
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..1903c79fc92e6e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1389,7 +1389,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
- tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
+ options.fusionFn(rewriter, candidateSliceOp, loops);
if (!fusedResult)
continue;
More information about the Mlir-commits
mailing list