[Mlir-commits] [mlir] [mlir][linalg] Extend `FuseElementwiseOps` pattern to work with named ops (PR #144922)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jun 21 09:29:48 PDT 2025
https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/144922
>From c76a8ccd542376b2cf00e4fbcc1da3c38c1a1f8e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 11:02:38 -0500
Subject: [PATCH 01/13] Make fusion work on any LinalgOp
---
.../Dialect/Linalg/Transforms/Transforms.h | 4 +-
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 42 ++++++++++---------
2 files changed, 24 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c202..0420edba2b300 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -511,8 +511,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
-llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
- GenericOp consumer,
+llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
+ LinalgOp consumer,
OpOperand *fusedOperand);
/// Try to peel and canonicalize loop `op` and return the new result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3a57f368d4425..498563e605fef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -75,11 +75,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
// of the fused producer & consumer after the fusion can still compute the
// bounds of the op.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
- GenericOp producer, GenericOp consumer,
+ LinalgOp producer, LinalgOp consumer,
ArrayRef<OpOperand *> opOperandsToIgnore) {
SmallVector<AffineMap> indexingMaps;
- SmallVector<GenericOp> ops = {producer, consumer};
+ SmallVector<LinalgOp> ops = {producer, consumer};
for (auto &op : ops) {
for (auto &opOperand : op->getOpOperands()) {
if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -108,7 +108,7 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
- GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
+ LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
@@ -138,8 +138,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
if (!fusedOperand)
return false;
- auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
- auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
+ auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());
// Check producer and consumer are generic ops.
if (!producer || !consumer)
@@ -213,16 +213,16 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
- RewriterBase &rewriter, GenericOp fusedOp,
+ RewriterBase &rewriter, LinalgOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
- auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
- auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
+ auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.
Block &producerBlock = producer->getRegion(0).front();
Block &consumerBlock = consumer->getRegion(0).front();
OpBuilder::InsertionGuard guard(rewriter);
- Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+ Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
IRMapping mapper;
// 2. Add an index operation for every fused loop dimension and use the
@@ -329,7 +329,7 @@ static void generateFusedElementwiseOpRegion(
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
// Sanity checks.
- assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+ assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
"Ill-formed GenericOp region");
}
@@ -339,8 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
assert(areElementwiseOpsFusable(fusedOperand) &&
"expected elementwise operation pre-conditions to pass");
auto producerResult = cast<OpResult>(fusedOperand->get());
- auto producer = cast<GenericOp>(producerResult.getOwner());
- auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+ auto producer = cast<LinalgOp>(producerResult.getOwner());
+ auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
@@ -415,12 +415,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
}
// Generate the fused op.
+ // auto fusedOp = cloneWithoutRegions(rewriter, consumer,
+ // fusedResultTypes, fusedInputOperands);
+ // fusedOp.setIndexingMapsAttr(idxMap);
+ // fusedOp.setIteratorTypesAttr(itTp);
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
- fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
- consumer.getIteratorTypes(),
- /*doc=*/nullptr,
- /*library_call=*/nullptr);
+ fusedOutputOperands, fusedIndexMaps,
+ consumer.getIteratorTypesArray());
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
@@ -459,14 +461,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
namespace {
/// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
public:
FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
PatternBenefit benefit = 1)
- : OpRewritePattern<GenericOp>(context, benefit),
+ : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
controlFn(std::move(fun)) {}
- LogicalResult matchAndRewrite(GenericOp genericOp,
+ LogicalResult matchAndRewrite(LinalgOp genericOp,
PatternRewriter &rewriter) const override {
// Find the first operand that is defined by another generic op on tensors.
for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -493,7 +495,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
rewriter.eraseOp(genericOp);
return success();
}
- return failure();
+ return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
}
private:
>From 20b25f3b4b75a67fcadb94720fb13b915ce1bc29 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 11:35:37 -0500
Subject: [PATCH 02/13] format and add test
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 12 ++++-------
.../Dialect/Linalg/fusion-elementwise.mlir | 21 +++++++++++++++++++
2 files changed, 25 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 0b5e3d1b123b3..688244f44cbe7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction based
/// on an optimized fusion.
-llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
- LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
+llvm::SmallDenseSet<int>
+mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
+ OpOperand *fusedOperand) {
llvm::SmallDenseSet<int> preservedProducerResults;
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
@@ -416,14 +417,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
}
// Generate the fused op.
- // auto fusedOp = cloneWithoutRegions(rewriter, consumer,
- // fusedResultTypes, fusedInputOperands);
- // fusedOp.setIndexingMapsAttr(idxMap);
- // fusedOp.setIteratorTypesAttr(itTp);
auto fusedOp = rewriter.create<GenericOp>(
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
- fusedOutputOperands, fusedIndexMaps,
- consumer.getIteratorTypesArray());
+ fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
if (!fusedOp.getShapesToLoopsMap()) {
// Fused op has invalid indexing maps. Typically this means something is off
// in the input, but going ahead here would result in verification errors.
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index bd9977f1410b9..db24d6d5f027a 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -59,3 +59,24 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK-NOT: linalg.generic
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+ %fill = tensor.empty() : tensor<8xf32>
+ %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+ %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+ return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.generic
>From 8e471a750a962feea17d99c27bf2bdb17a991ad1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 13:23:22 -0500
Subject: [PATCH 03/13] fix typo in test
---
mlir/test/Dialect/Linalg/fusion-elementwise.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index db24d6d5f027a..9b5f3d12f3d21 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -74,7 +74,7 @@ func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.}}) outs(%[[EMPTY]] :
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
>From d723913f901841e3f8b6ee7ee4b71ec2e66e30ab Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 13:47:52 -0500
Subject: [PATCH 04/13] add same test for other fusion pass -linalg-fuse-elementwise-ops
---
.../Linalg/fusion-elementwise-ops.mlir | 21 +++++++++++++++++++
1 file changed, 21 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 66fc55fadf8fa..b581567cf57a7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1014,3 +1014,24 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+ %fill = tensor.empty() : tensor<8xf32>
+ %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+ %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+ return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.generic
>From 5280b873e345c7976b8deee5f01cdba354d6df28 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 16:08:02 -0500
Subject: [PATCH 05/13] fix bug with no output bb args and add test
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 23 ++++++++++++
.../Dialect/Linalg/fusion-elementwise.mlir | 35 ++++++++++++++++++-
2 files changed, 57 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 688244f44cbe7..fc435b47f5977 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -222,8 +222,31 @@ static void generateFusedElementwiseOpRegion(
auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.
+
+ // Since some ops, like `linalg.map`, do not have block arguments for init operands
+ // then we first "generalize" the block by adding arguments for init operands when
+ // they aren't present. We detect this case by checking if
+ // `getOpOperandsMatchingBBargs() == getDpsInputOperands();
Block &producerBlock = producer->getRegion(0).front();
+ if (producer.getOpOperandsMatchingBBargs() ==
+ producer.getDpsInputOperands()) {
+ for (auto init : producer.getDpsInits()) {
+ Type bbType = isa<ShapedType>(init.getType())
+ ? cast<ShapedType>(init.getType()).getElementType()
+ : init.getType();
+ producerBlock.addArgument(bbType, producer.getLoc());
+ }
+ }
Block &consumerBlock = consumer->getRegion(0).front();
+ if (consumer.getOpOperandsMatchingBBargs() ==
+ consumer.getDpsInputOperands()) {
+ for (auto init : consumer.getDpsInits()) {
+ Type bbType = isa<ShapedType>(init.getType())
+ ? cast<ShapedType>(init.getType()).getElementType()
+ : init.getType();
+ consumerBlock.addArgument(bbType, consumer.getLoc());
+ }
+ }
OpBuilder::InsertionGuard guard(rewriter);
Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
IRMapping mapper;
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 9b5f3d12f3d21..18ca8b42fa79c 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -79,4 +79,37 @@ func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
-// CHECK-NOT: linalg.generic
+// CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+ %init = tensor.empty() : tensor<8xi1>
+ %initf = tensor.empty() : tensor<8xf32>
+ %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
+ (%in0 : f32, %in1 : f32) {
+ %cmp = arith.cmpf olt, %in0, %in1 : f32
+ linalg.yield %cmp : i1
+ }
+ %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
+ return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[EXP0:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT: %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT: %[[EXP1:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT: %[[SQRT1:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
+// CHECK-NEXT: %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
+// CHECK-NEXT: linalg.yield %[[RES]]
+// CHECK-NOT: linalg.map
+
>From c2f52bc4154b62281bfcd8521154faf81e04c1f1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 18:11:45 -0500
Subject: [PATCH 06/13] add linalg.elementwise test
---
.../Dialect/Linalg/fusion-elementwise.mlir | 28 +++++++++++++++++--
1 file changed, 26 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 18ca8b42fa79c..2f9011cd5e52b 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -65,8 +65,8 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
%fill = tensor.empty() : tensor<8xf32>
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
- %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
- return %mapped_65 : tensor<8xf32>
+ %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+ return %sqrt : tensor<8xf32>
}
// CHECK-LABEL: func @map_ops
@@ -113,3 +113,27 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
// CHECK-NEXT: linalg.yield %[[RES]]
// CHECK-NOT: linalg.map
+// -----
+
+func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+ %fill = tensor.empty() : tensor<8xf32>
+ %add = linalg.elementwise
+ kind=#linalg.elementwise_kind<add>
+ ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>) -> tensor<8xf32>
+ %wqrt = linalg.elementwise
+ kind=#linalg.elementwise_kind<sqrt>
+ ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>) -> tensor<8xf32>
+ return %wqrt : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @elementwise_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
>From 8d2e8e0be55a1451e8b9774dddf9199158c98b2d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 18:13:22 -0500
Subject: [PATCH 07/13] fix formatting
---
.../lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index fc435b47f5977..c1fc003d3f05d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -223,10 +223,10 @@ static void generateFusedElementwiseOpRegion(
auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
// Build the region of the fused op.
- // Since some ops, like `linalg.map`, do not have block arguments for init operands
- // then we first "generalize" the block by adding arguments for init operands when
- // they aren't present. We detect this case by checking if
- // `getOpOperandsMatchingBBargs() == getDpsInputOperands();
+ // Since some ops, like `linalg.map`, do not have block arguments for init
+ // operands then we first "generalize" the block by adding arguments for init
+ // operands when they aren't present. We detect this case by checking if
+ // `getOpOperandsMatchingBBargs() == getDpsInputOperands()
Block &producerBlock = producer->getRegion(0).front();
if (producer.getOpOperandsMatchingBBargs() ==
producer.getDpsInputOperands()) {
>From 58582bfd75576e2ea089949207e25727cea7ca69 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 18:23:33 -0500
Subject: [PATCH 08/13] use getElementTypeOrSelf for cleanup
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 18 ++++++------------
1 file changed, 6 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index c1fc003d3f05d..6ec13e33055ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -230,22 +230,16 @@ static void generateFusedElementwiseOpRegion(
Block &producerBlock = producer->getRegion(0).front();
if (producer.getOpOperandsMatchingBBargs() ==
producer.getDpsInputOperands()) {
- for (auto init : producer.getDpsInits()) {
- Type bbType = isa<ShapedType>(init.getType())
- ? cast<ShapedType>(init.getType()).getElementType()
- : init.getType();
- producerBlock.addArgument(bbType, producer.getLoc());
- }
+ for (auto init : producer.getDpsInits())
+ producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
+ producer.getLoc());
}
Block &consumerBlock = consumer->getRegion(0).front();
if (consumer.getOpOperandsMatchingBBargs() ==
consumer.getDpsInputOperands()) {
- for (auto init : consumer.getDpsInits()) {
- Type bbType = isa<ShapedType>(init.getType())
- ? cast<ShapedType>(init.getType()).getElementType()
- : init.getType();
- consumerBlock.addArgument(bbType, consumer.getLoc());
- }
+ for (auto init : consumer.getDpsInits())
+ consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
+ consumer.getLoc());
}
OpBuilder::InsertionGuard guard(rewriter);
Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
>From cf67ab67bf4da9f8c65137fd627b6ba2d8da0ebb Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 18:30:12 -0500
Subject: [PATCH 09/13] switch elementwise test to broadcast version
---
.../Dialect/Linalg/fusion-elementwise.mlir | 20 +++++++++++--------
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 2f9011cd5e52b..575f21b8f09f9 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -115,21 +115,25 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
// -----
-func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
- %fill = tensor.empty() : tensor<8xf32>
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tensor<8x10xf32> {
+ %fill = tensor.empty() : tensor<8x10xf32>
%add = linalg.elementwise
kind=#linalg.elementwise_kind<add>
- ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>) -> tensor<8xf32>
- %wqrt = linalg.elementwise
+ indexing_maps = [#bcast, #identity, #identity]
+ ins(%in1, %in2: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) -> tensor<8x10xf32>
+ %sqrt = linalg.elementwise
kind=#linalg.elementwise_kind<sqrt>
- ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>) -> tensor<8xf32>
- return %wqrt : tensor<8xf32>
+ indexing_maps = [#identity, #identity]
+ ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>) -> tensor<8x10xf32>
+ return %sqrt : tensor<8x10xf32>
}
// CHECK-LABEL: func @elementwise_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
>From 7d402c1f75a09d5b9cda01fd49ae287928a47364 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Fri, 20 Jun 2025 14:23:35 -0500
Subject: [PATCH 10/13] remove block args that were added (hacky)
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 22 ++++++++++++++-----
1 file changed, 17 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6ec13e33055ce..c3b5765a5c4ad 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -226,17 +226,21 @@ static void generateFusedElementwiseOpRegion(
// Since some ops, like `linalg.map`, do not have block arguments for init
// operands then we first "generalize" the block by adding arguments for init
// operands when they aren't present. We detect this case by checking if
- // `getOpOperandsMatchingBBargs() == getDpsInputOperands()
+ // `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
+ // TODO: This is hacky and should not be merged. Keeping for now for testing
+ // purposes in the meantime, but need a better way
Block &producerBlock = producer->getRegion(0).front();
- if (producer.getOpOperandsMatchingBBargs() ==
- producer.getDpsInputOperands()) {
+ bool addOutputArgsProducer =
+ producer.getOpOperandsMatchingBBargs() == producer.getDpsInputOperands();
+ if (addOutputArgsProducer) {
for (auto init : producer.getDpsInits())
producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
producer.getLoc());
}
Block &consumerBlock = consumer->getRegion(0).front();
- if (consumer.getOpOperandsMatchingBBargs() ==
- consumer.getDpsInputOperands()) {
+ bool addOutputArgsConsumer =
+ consumer.getOpOperandsMatchingBBargs() == consumer.getDpsInputOperands();
+ if (addOutputArgsConsumer) {
for (auto init : consumer.getDpsInits())
consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
consumer.getLoc());
@@ -350,6 +354,14 @@ static void generateFusedElementwiseOpRegion(
// Sanity checks.
assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
"Ill-formed GenericOp region");
+ // Erase added args in case that the ops are still live after fusion.
+ // TODO: Remove along with hacky code above.
+ if (addOutputArgsProducer)
+ producerBlock.eraseArguments(producer.getNumDpsInputs(),
+ producer.getNumDpsInits());
+ if (addOutputArgsConsumer)
+ consumerBlock.eraseArguments(consumer.getNumDpsInputs(),
+ consumer.getNumDpsInits());
}
FailureOr<mlir::linalg::ElementwiseOpFusionResult>
>From b1d15b2822953882376661a1b66ec8adc5cc01a1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 21 Jun 2025 11:19:06 -0500
Subject: [PATCH 11/13] add requested tests
---
.../Dialect/Linalg/fusion-elementwise.mlir | 63 +++++++++++++++++++
1 file changed, 63 insertions(+)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 575f21b8f09f9..8aa6974d5f0e4 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -141,3 +141,66 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
// CHECK-NEXT: linalg.yield %[[SQRT]]
// CHECK-NOT: linalg.map
+
+// -----
+
+func.func @map_multi_ops(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> {
+ %fill = tensor.empty() : tensor<8xf32>
+ %add_exp = linalg.map ins(%arg0, %arg1: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+ (%in0 : f32, %in1 : f32) {
+ %add = arith.addf %in0, %in1 : f32
+ %exp = math.exp %add : f32
+ linalg.yield %exp : f32
+ }
+ %sqrt_mul = linalg.map ins(%add_exp, %arg2 : tensor<8xf32>, tensor<8xf32>) outs(%fill : tensor<8xf32>)
+ (%in0 : f32, %in1 : f32) {
+ %sqrt = math.sqrt %in0 : f32
+ %mul = arith.mulf %sqrt, %in1 : f32
+ linalg.yield %mul : f32
+ }
+ return %sqrt_mul : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_multi_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[EXP:.*]] = math.exp %[[ADD]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[EXP]]
+// CHECK-NEXT: %[[MUL:.*]] = arith.mulf %[[SQRT]], %[[IN2]]
+// CHECK-NEXT: linalg.yield %[[MUL]]
+// CHECK-NOT: linalg.map
+
+// -----
+
+#identity = affine_map<(d0, d1) -> (d0, d1)>
+#bcast = affine_map<(d0, d1) -> (d0)>
+func.func @map_genric_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tensor<8x10xf32> {
+ %fill = tensor.empty() : tensor<8x10xf32>
+ %add = linalg.generic
+ {indexing_maps = [#bcast, #identity, #identity], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0, %arg1: tensor<8xf32>, tensor<8x10xf32>) outs(%fill: tensor<8x10xf32>) {
+ ^bb0(%in0: f32, %in1: f32, %out: f32):
+ %add = arith.addf %in0, %in1 : f32
+ linalg.yield %add : f32
+ } -> tensor<8x10xf32>
+ %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8x10xf32>) outs(%fill : tensor<8x10xf32>)
+ return %sqrt : tensor<8x10xf32>
+}
+
+// CHECK-LABEL: func @map_genric_ops
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
+// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
+// CHECK-NEXT: linalg.yield %[[SQRT]]
+// CHECK-NOT: linalg.map
>From 459bcb47feea8fa75771ec841eadcad96336a430 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 21 Jun 2025 11:24:29 -0500
Subject: [PATCH 12/13] add checks for nontrivial map cases
---
mlir/test/Dialect/Linalg/fusion-elementwise.mlir | 6 ++++++
.../lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp | 7 ++++++-
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 8aa6974d5f0e4..d4b25eb4be691 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -130,11 +130,14 @@ func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8x10xf32>) -> tenso
return %sqrt : tensor<8x10xf32>
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func @elementwise_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
@@ -193,11 +196,14 @@ func.func @map_genric_ops(%arg0: tensor<8xf32>, %arg1: tensor<8x10xf32>) -> tens
return %sqrt : tensor<8x10xf32>
}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-LABEL: func @map_genric_ops
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8x10xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x10xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP0]], #[[MAP0]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index cb215197253bb..6b9abd34b7781 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -151,9 +151,14 @@ struct TestLinalgElementwiseFusion
MLIRContext *context = &this->getContext();
func::FuncOp funcOp = this->getOperation();
+ auto controlFn = [](OpOperand *operand) {
+ auto owner = cast<linalg::LinalgOp>(operand->getOwner());
+ auto producer = cast<linalg::LinalgOp>(operand->get().getDefiningOp());
+ return (linalg::isElementwise(owner) && linalg::isElementwise(producer)) && (!isa<linalg::BroadcastOp>(producer) && !isa<linalg::BroadcastOp>(owner));
+ };
if (fuseGenericOps) {
RewritePatternSet fusionPatterns(context);
- auto controlFn = [](OpOperand *operand) { return true; };
+ // auto controlFn = [](OpOperand *operand) { return true; };
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
if (failed(applyPatternsGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
>From f7e164beebdc8195b6781c2a8ce8bc1bea7757cd Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Sat, 21 Jun 2025 11:29:26 -0500
Subject: [PATCH 13/13] revert unintended change
---
.../lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp | 7 +------
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
index 6b9abd34b7781..cb215197253bb 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -151,14 +151,9 @@ struct TestLinalgElementwiseFusion
MLIRContext *context = &this->getContext();
func::FuncOp funcOp = this->getOperation();
- auto controlFn = [](OpOperand *operand) {
- auto owner = cast<linalg::LinalgOp>(operand->getOwner());
- auto producer = cast<linalg::LinalgOp>(operand->get().getDefiningOp());
- return (linalg::isElementwise(owner) && linalg::isElementwise(producer)) && (!isa<linalg::BroadcastOp>(producer) && !isa<linalg::BroadcastOp>(owner));
- };
if (fuseGenericOps) {
RewritePatternSet fusionPatterns(context);
- // auto controlFn = [](OpOperand *operand) { return true; };
+ auto controlFn = [](OpOperand *operand) { return true; };
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
if (failed(applyPatternsGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
More information about the Mlir-commits mailing list