[Mlir-commits] [mlir] [mlir][linalg] Elementwise fusion for any `LinalgOp` (PR #144922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 19 16:13:38 PDT 2025


https://github.com/srcarroll updated https://github.com/llvm/llvm-project/pull/144922

>From c76a8ccd542376b2cf00e4fbcc1da3c38c1a1f8e Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 11:02:38 -0500
Subject: [PATCH 1/7] Make fusion work on any LinalgOp

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  4 +-
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 42 ++++++++++---------
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c202..0420edba2b300 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -511,8 +511,8 @@ fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 /// * There is a chance that the implementation of the transformation does not
 /// agree with the result of this method. This function gives a prediction based
 /// on an optimized fusion.
-llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer,
-                                                     GenericOp consumer,
+llvm::SmallDenseSet<int> getPreservedProducerResults(LinalgOp producer,
+                                                     LinalgOp consumer,
                                                      OpOperand *fusedOperand);
 
 /// Try to peel and canonicalize loop `op` and return the new result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3a57f368d4425..498563e605fef 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -75,11 +75,11 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
 // of the fused producer & consumer after the fusion can still compute the
 // bounds of the op.
 static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
-    GenericOp producer, GenericOp consumer,
+    LinalgOp producer, LinalgOp consumer,
     ArrayRef<OpOperand *> opOperandsToIgnore) {
   SmallVector<AffineMap> indexingMaps;
 
-  SmallVector<GenericOp> ops = {producer, consumer};
+  SmallVector<LinalgOp> ops = {producer, consumer};
   for (auto &op : ops) {
     for (auto &opOperand : op->getOpOperands()) {
       if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
@@ -108,7 +108,7 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
 /// agree with the result of this method. This function gives a prediction based
 /// on an optimized fusion.
 llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
-    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
+    LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
   llvm::SmallDenseSet<int> preservedProducerResults;
   llvm::SmallVector<OpOperand *> opOperandsToIgnore;
 
@@ -138,8 +138,8 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
   if (!fusedOperand)
     return false;
 
-  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
-  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = fusedOperand->get().getDefiningOp<LinalgOp>();
+  auto consumer = dyn_cast<LinalgOp>(fusedOperand->getOwner());
 
   // Check producer and consumer are generic ops.
   if (!producer || !consumer)
@@ -213,16 +213,16 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
 /// Generate the region of the fused tensor operation. The region of the fused
 /// op must be empty.
 static void generateFusedElementwiseOpRegion(
-    RewriterBase &rewriter, GenericOp fusedOp,
+    RewriterBase &rewriter, LinalgOp fusedOp,
     AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
     unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
-  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
-  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
+  auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
   Block &producerBlock = producer->getRegion(0).front();
   Block &consumerBlock = consumer->getRegion(0).front();
   OpBuilder::InsertionGuard guard(rewriter);
-  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
+  Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
 
   // 2. Add an index operation for every fused loop dimension and use the
@@ -329,7 +329,7 @@ static void generateFusedElementwiseOpRegion(
   rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
 
   // Sanity checks.
-  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
+  assert(fusedBlock->getNumArguments() == fusedOp->getNumOperands() &&
          "Ill-formed GenericOp region");
 }
 
@@ -339,8 +339,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   assert(areElementwiseOpsFusable(fusedOperand) &&
          "expected elementwise operation pre-conditions to pass");
   auto producerResult = cast<OpResult>(fusedOperand->get());
-  auto producer = cast<GenericOp>(producerResult.getOwner());
-  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
+  auto producer = cast<LinalgOp>(producerResult.getOwner());
+  auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // TODO: allow fusing the producer of an output operand.
   assert(consumer.isDpsInput(fusedOperand) &&
          "expected producer of input operand");
@@ -415,12 +415,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   }
 
   // Generate the fused op.
+  // auto fusedOp = cloneWithoutRegions(rewriter, consumer,
+  //                              fusedResultTypes, fusedInputOperands);
+  // fusedOp.setIndexingMapsAttr(idxMap);
+  // fusedOp.setIteratorTypesAttr(itTp);
   auto fusedOp = rewriter.create<GenericOp>(
       consumer.getLoc(), fusedResultTypes, fusedInputOperands,
-      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-      consumer.getIteratorTypes(),
-      /*doc=*/nullptr,
-      /*library_call=*/nullptr);
+      fusedOutputOperands, fusedIndexMaps,
+      consumer.getIteratorTypesArray());
   if (!fusedOp.getShapesToLoopsMap()) {
     // Fused op has invalid indexing maps. Typically this means something is off
     // in the input, but going ahead here would result in verification errors.
@@ -459,14 +461,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
 
 namespace {
 /// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+class FuseElementwiseOps : public OpInterfaceRewritePattern<LinalgOp> {
 public:
   FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                      PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
+      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
         controlFn(std::move(fun)) {}
 
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+  LogicalResult matchAndRewrite(LinalgOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
     for (OpOperand &opOperand : genericOp->getOpOperands()) {
@@ -493,7 +495,7 @@ class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
       rewriter.eraseOp(genericOp);
       return success();
     }
-    return failure();
+    return rewriter.notifyMatchFailure(genericOp, "no fusable operands");
   }
 
 private:

>From 20b25f3b4b75a67fcadb94720fb13b915ce1bc29 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 11:35:37 -0500
Subject: [PATCH 2/7] format and add test

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 12 ++++-------
 .../Dialect/Linalg/fusion-elementwise.mlir    | 21 +++++++++++++++++++
 2 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 0b5e3d1b123b3..688244f44cbe7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
 /// * There is a chance that the implementation of the transformation does not
 /// agree with the result of this method. This function gives a prediction based
 /// on an optimized fusion.
-llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
-    LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
+llvm::SmallDenseSet<int>
+mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
+                                          OpOperand *fusedOperand) {
   llvm::SmallDenseSet<int> preservedProducerResults;
   llvm::SmallVector<OpOperand *> opOperandsToIgnore;
 
@@ -416,14 +417,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   }
 
   // Generate the fused op.
-  // auto fusedOp = cloneWithoutRegions(rewriter, consumer,
-  //                              fusedResultTypes, fusedInputOperands);
-  // fusedOp.setIndexingMapsAttr(idxMap);
-  // fusedOp.setIteratorTypesAttr(itTp);
   auto fusedOp = rewriter.create<GenericOp>(
       consumer.getLoc(), fusedResultTypes, fusedInputOperands,
-      fusedOutputOperands, fusedIndexMaps,
-      consumer.getIteratorTypesArray());
+      fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
   if (!fusedOp.getShapesToLoopsMap()) {
     // Fused op has invalid indexing maps. Typically this means something is off
     // in the input, but going ahead here would result in verification errors.
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index bd9977f1410b9..db24d6d5f027a 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -59,3 +59,24 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
 //       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
 //  CHECK-SAME:       outs(%[[EMPTY]] :
 //   CHECK-NOT:   linalg.generic
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+    %fill = tensor.empty() : tensor<8xf32>
+    %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+    %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+    return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+//       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.}}) outs(%[[EMPTY]] :
+//  CHECK-NEXT:   ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+//  CHECK-NEXT:     %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+//  CHECK-NEXT:     %[[SQRT:.*]] = math.sqrt %[[ADD]]
+//  CHECK-NEXT:     linalg.yield %[[SQRT]] 
+//   CHECK-NOT:   linalg.generic

>From 8e471a750a962feea17d99c27bf2bdb17a991ad1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 13:23:22 -0500
Subject: [PATCH 3/7] fix typo in test

---
 mlir/test/Dialect/Linalg/fusion-elementwise.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index db24d6d5f027a..9b5f3d12f3d21 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -74,7 +74,7 @@ func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
 //       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
 //       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
-//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.}}) outs(%[[EMPTY]] :
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
 //  CHECK-NEXT:   ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
 //  CHECK-NEXT:     %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
 //  CHECK-NEXT:     %[[SQRT:.*]] = math.sqrt %[[ADD]]

>From d723913f901841e3f8b6ee7ee4b71ec2e66e30ab Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 13:47:52 -0500
Subject: [PATCH 4/7] add same test for other fusion pass
 -linalg-fuse-elementwise-ops

---
 .../Linalg/fusion-elementwise-ops.mlir        | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 66fc55fadf8fa..b581567cf57a7 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -1014,3 +1014,24 @@ module {
 //   CHECK-DAG:     %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 //       CHECK:     linalg.yield %[[T3]] : f32
 //       CHECK:   return %[[GENERIC]]
+
+// -----
+
+func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+    %fill = tensor.empty() : tensor<8xf32>
+    %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
+    %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+    return %mapped_65 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+//       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+//  CHECK-NEXT:   ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+//  CHECK-NEXT:     %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+//  CHECK-NEXT:     %[[SQRT:.*]] = math.sqrt %[[ADD]]
+//  CHECK-NEXT:     linalg.yield %[[SQRT]] 
+//   CHECK-NOT:   linalg.generic

>From 5280b873e345c7976b8deee5f01cdba354d6df28 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 16:08:02 -0500
Subject: [PATCH 5/7] fix bug with no output bb args and add test

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 23 ++++++++++++
 .../Dialect/Linalg/fusion-elementwise.mlir    | 35 ++++++++++++++++++-
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 688244f44cbe7..fc435b47f5977 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -222,8 +222,31 @@ static void generateFusedElementwiseOpRegion(
   auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
   auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
+
+  // Since some ops, like `linalg.map`, do not have block arguments for init operands
+  // then we first "generalize" the block by adding arguments for init operands when
+  // they aren't present. We detect this case by checking if
+  // `getOpOperandsMatchingBBargs() == getDpsInputOperands(); 
   Block &producerBlock = producer->getRegion(0).front();
+  if (producer.getOpOperandsMatchingBBargs() ==
+      producer.getDpsInputOperands()) {
+    for (auto init : producer.getDpsInits()) {
+      Type bbType = isa<ShapedType>(init.getType())
+                        ? cast<ShapedType>(init.getType()).getElementType()
+                        : init.getType();
+      producerBlock.addArgument(bbType, producer.getLoc());
+    }
+  }
   Block &consumerBlock = consumer->getRegion(0).front();
+  if (consumer.getOpOperandsMatchingBBargs() ==
+      consumer.getDpsInputOperands()) {
+    for (auto init : consumer.getDpsInits()) {
+      Type bbType = isa<ShapedType>(init.getType())
+                        ? cast<ShapedType>(init.getType()).getElementType()
+                        : init.getType();
+      consumerBlock.addArgument(bbType, consumer.getLoc());
+    }
+  }
   OpBuilder::InsertionGuard guard(rewriter);
   Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 9b5f3d12f3d21..18ca8b42fa79c 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -79,4 +79,37 @@ func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
 //  CHECK-NEXT:     %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
 //  CHECK-NEXT:     %[[SQRT:.*]] = math.sqrt %[[ADD]]
 //  CHECK-NEXT:     linalg.yield %[[SQRT]] 
-//   CHECK-NOT:   linalg.generic
+//   CHECK-NOT:   linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  %init = tensor.empty() : tensor<8xi1>
+  %initf = tensor.empty() : tensor<8xf32>
+  %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs (%init : tensor<8xi1>)
+    (%in0 : f32, %in1 : f32) {
+      %cmp = arith.cmpf olt, %in0, %in1 : f32
+      linalg.yield %cmp : i1
+  }
+  %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>) 
+  return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+//       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+//  CHECK-NEXT:   ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+//  CHECK-NEXT:     %[[EXP0:.*]] = math.exp %[[IN1]]
+//  CHECK-NEXT:     %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+//  CHECK-NEXT:     %[[EXP1:.*]] = math.exp %[[IN1]]
+//  CHECK-NEXT:     %[[SQRT1:.*]] = math.sqrt %[[IN0]]
+//  CHECK-NEXT:     %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
+//  CHECK-NEXT:     %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
+//  CHECK-NEXT:     linalg.yield %[[RES]] 
+//   CHECK-NOT:   linalg.map
+

>From c2f52bc4154b62281bfcd8521154faf81e04c1f1 Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 18:11:45 -0500
Subject: [PATCH 6/7] add linalg.elementwise test

---
 .../Dialect/Linalg/fusion-elementwise.mlir    | 28 +++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
index 18ca8b42fa79c..2f9011cd5e52b 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise.mlir
@@ -65,8 +65,8 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
 func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
     %fill = tensor.empty() : tensor<8xf32>
     %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
-    %mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
-    return %mapped_65 : tensor<8xf32>
+    %sqrt = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
+    return %sqrt : tensor<8xf32>
 }
 
 // CHECK-LABEL: func @map_ops
@@ -113,3 +113,27 @@ func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> te
 //  CHECK-NEXT:     linalg.yield %[[RES]] 
 //   CHECK-NOT:   linalg.map
 
+// -----
+
+func.func @elementwise_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
+    %fill = tensor.empty() : tensor<8xf32>
+    %add = linalg.elementwise
+      kind=#linalg.elementwise_kind<add>
+      ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>) -> tensor<8xf32>
+    %wqrt = linalg.elementwise
+      kind=#linalg.elementwise_kind<sqrt>
+      ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>) -> tensor<8xf32>
+    return %wqrt : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @elementwise_ops
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+//       CHECK:   %[[FUSED_OP:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+//  CHECK-NEXT:   ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+//  CHECK-NEXT:     %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
+//  CHECK-NEXT:     %[[SQRT:.*]] = math.sqrt %[[ADD]]
+//  CHECK-NEXT:     linalg.yield %[[SQRT]] 
+//   CHECK-NOT:   linalg.map

>From 8d2e8e0be55a1451e8b9774dddf9199158c98b2d Mon Sep 17 00:00:00 2001
From: Sam <srcarroll314 at gmail.com>
Date: Thu, 19 Jun 2025 18:13:22 -0500
Subject: [PATCH 7/7] fix formatting

---
 .../lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index fc435b47f5977..c1fc003d3f05d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -223,10 +223,10 @@ static void generateFusedElementwiseOpRegion(
   auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
 
-  // Since some ops, like `linalg.map`, do not have block arguments for init operands
-  // then we first "generalize" the block by adding arguments for init operands when
-  // they aren't present. We detect this case by checking if
-  // `getOpOperandsMatchingBBargs() == getDpsInputOperands(); 
+  // Some ops, like `linalg.map`, do not have block arguments for init
+  // operands, so we first "generalize" the block by adding arguments for init
+  // operands when they aren't present. We detect this case by checking whether
+  // `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
   Block &producerBlock = producer->getRegion(0).front();
   if (producer.getOpOperandsMatchingBBargs() ==
       producer.getDpsInputOperands()) {



More information about the Mlir-commits mailing list