[Mlir-commits] [mlir] Generalize fusion (PR #144922)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 19 09:22:20 PDT 2025
https://github.com/srcarroll created https://github.com/llvm/llvm-project/pull/144922
(No pull request description was provided.)
>From c76a8ccd542376b2cf00e4fbcc1da3c38c1a1f8e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 11:02:38 -0500
Subject: [PATCH] Make fusion work on any LinalgOp
---
.../Dialect/Linalg/Transforms/Transforms.h | 4 +-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 42 ++++++++++---------
2 files changed, 24 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c202..0420edba2b300 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -511,8 +511,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
-llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
- GenericOp consumer,
+llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
+ LinalgOp consumer,
OpOperand *fusedOperand);
/// Try to peel and canonicalize loop `op` and return the new result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3a57f368d4425..498563e605fef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -75,11 +75,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
- GenericOp producer, GenericOp consumer,
+ LinalgOp producer, LinalgOp consumer,
ArrayRef<OpOperand *> opOperandsToIgnore) {
SmallVector<AffineMap> indexingMaps;
- SmallVector<GenericOp> ops = {producer, consumer};
+ SmallVector<LinalgOp> ops = {producer, consumer};
for (auto &op : ops) {
for (auto &opOperand : op->getOpOperands()) {
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -108,7 +108,7 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
- GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
+ LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
@@ -138,8 +138,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;
- auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
- auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
+ auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());
// Check producer and consumer are generic ops.
if (!producer || !consumer)
@@ -213,16 +213,16 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
- RewriterBase &rewriter, GenericOp fusedOp,
+ RewriterBase &rewriter, LinalgOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
- auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
- auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
+ auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
OpBuilder::InsertionGuard guard(rewriter);
- Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+ Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
IRMapping mapper;
// 2. Add an index operation for every fused loop dimension and use the
@@ -329,7 +329,7 @@ static void generateFusedElementwiseOpRegion(
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
// Sanity checks.
- assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+ assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
"Ill-formed GenericOp region");
}
@@ -339,8 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
- auto producer = cast<GenericOp>(producerResult.getOwner());
- auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = cast<LinalgOp>(producerResult.getOwner());
+ auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
@@ -415,12 +415,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
}
// Generate the fused op.
+ // auto fusedOp = cloneWithoutRegions(rewriter, consumer,
+ // fusedResultTypes, fusedInputOperands);
+ // fusedOp.setIndexingMapsAttr(idxMap);
+ // fusedOp.setIteratorTypesAttr(itTp);
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
- fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- consumer.getIteratorTypes(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
+ fusedOutputOperands, fusedIndexMaps,
+ consumer.getIteratorTypesArray());
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
@@ -459,14 +461,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit),
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(LinalgOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -493,7 +495,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
rewriter.eraseOp(genericOp);
return success();
}
- return failure();
+ return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
}
private:
More information about the Mlir-commits
mailing list