[Mlir-commits] [mlir] [mlir] added a check in the walk to prevent catching a cos in a nested region (PR #190064)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 1 18:07:39 PDT 2026


https://github.com/yebinchon updated https://github.com/llvm/llvm-project/pull/190064

>From 30c8adc4679003b688596213233d3549ff8b3692 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Wed, 1 Apr 2026 13:53:56 -0700
Subject: [PATCH 1/4] added a check in the walk to prevent catching an op in a
 nested region

---
 .../Dialect/Math/Transforms/SincosFusion.cpp  |  4 +++-
 mlir/test/Dialect/Math/sincos-fusion.mlir     | 23 +++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
index 69407df201cfa..c6602eb293cbb 100644
--- a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -27,8 +27,10 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
     mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
 
     math::CosOp cosOp = nullptr;
+    Block* sinBlock = sinOp->getBlock();
     sinOp->getBlock()->walk([&](math::CosOp op) {
-      if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
+      if (op->getBlock() == sinBlock && op.getOperand() == operand && 
+          op.getFastmath() == sinFastMathFlags) {
         cosOp = op;
         return WalkResult::interrupt();
       }
diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir
index 29fb9f12475b8..cf16f9f02f63a 100644
--- a/mlir/test/Dialect/Math/sincos-fusion.mlir
+++ b/mlir/test/Dialect/Math/sincos-fusion.mlir
@@ -74,6 +74,29 @@ func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
   func.return %0 : f32
 }
 
+// CHECK-LABEL:   func.func @sincos_no_fusion_nested_region(
+// CHECK-SAME:      %[[ARG0:.*]]: f32,
+// CHECK-SAME:      %[[ARG1:.*]]: i1) -> (f32, f32) {
+// CHECK:           %[[SIN:.*]] = math.sin %[[ARG0]] : f32
+// CHECK:           %[[IF:.*]] = scf.if %[[ARG1]] -> (f32) {
+// CHECK:             %[[COS:.*]] = math.cos %[[ARG0]] : f32
+// CHECK:             scf.yield %[[COS]] : f32
+// CHECK:           } else {
+// CHECK:             scf.yield %[[SIN]] : f32
+// CHECK:           }
+// CHECK:           return %[[SIN]], %[[IF]] : f32, f32
+// CHECK:         }
+func.func @sincos_no_fusion_nested_region(%arg0 : f32, %flag : i1) -> (f32, f32) {
+    %s = math.sin %arg0 : f32
+    %r = scf.if %flag -> f32 {
+      %c = math.cos %arg0 : f32
+      scf.yield %c : f32
+    } else {
+      scf.yield %s : f32
+    }
+    func.return %s, %r : f32, f32
+}
+
 // CHECK-LABEL:   func.func @sincos_fusion_preserve_fastmath(
 // CHECK-SAME:      %[[ARG0:.*]]: f32) -> (f32, f32) {
 // CHECK:           %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32

>From 4e1522fc0b9e01e60235328a2d7772c42ae4f519 Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Wed, 1 Apr 2026 14:02:32 -0700
Subject: [PATCH 2/4] apply clang-format to the nested-region check

---
 mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
index c6602eb293cbb..c6a8a93b35db4 100644
--- a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -27,9 +27,9 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
     mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
 
     math::CosOp cosOp = nullptr;
-    Block* sinBlock = sinOp->getBlock();
+    Block *sinBlock = sinOp->getBlock();
     sinOp->getBlock()->walk([&](math::CosOp op) {
-      if (op->getBlock() == sinBlock && op.getOperand() == operand && 
+      if (op->getBlock() == sinBlock && op.getOperand() == operand &&
           op.getFastmath() == sinFastMathFlags) {
         cosOp = op;
         return WalkResult::interrupt();

>From 5747efb8c78022e472b92fe8568ef2355bfe550c Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Wed, 1 Apr 2026 14:19:28 -0700
Subject: [PATCH 3/4] replace the walk with a for loop over the cos ops in the sin op's own block

---
 mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
index c6a8a93b35db4..d3d9c62423243 100644
--- a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -27,15 +27,13 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
     mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
 
     math::CosOp cosOp = nullptr;
-    Block *sinBlock = sinOp->getBlock();
-    sinOp->getBlock()->walk([&](math::CosOp op) {
-      if (op->getBlock() == sinBlock && op.getOperand() == operand &&
-          op.getFastmath() == sinFastMathFlags) {
+    for (auto op : sinOp->getBlock()->getOps<math::CosOp>()) {
+      if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
         cosOp = op;
         return WalkResult::interrupt();
       }
       return WalkResult::advance();
-    });
+    }
 
     if (!cosOp)
       return failure();

>From 88ff40a1e6bec39aaf852e6261f98cdd3ada5fff Mon Sep 17 00:00:00 2001
From: Yebin Chon <ychon at nvidia.com>
Date: Wed, 1 Apr 2026 18:07:23 -0700
Subject: [PATCH 4/4] fix compile errors left over from converting the walk to a for loop

---
 mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
index d3d9c62423243..5ff4f7c626347 100644
--- a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -30,9 +30,8 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
     for (auto op : sinOp->getBlock()->getOps<math::CosOp>()) {
       if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
         cosOp = op;
-        return WalkResult::interrupt();
+        break;
       }
-      return WalkResult::advance();
     }
 
     if (!cosOp)



More information about the Mlir-commits mailing list