[Mlir-commits] [mlir] 495e1a4 - [mlir] added a check in the walk to prevent catching a cos in a nested region (#190064)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 1 20:11:02 PDT 2026
Author: yebinchon
Date: 2026-04-01T20:10:56-07:00
New Revision: 495e1a42579c8f540303ab72e6449a448b89e537
URL: https://github.com/llvm/llvm-project/commit/495e1a42579c8f540303ab72e6449a448b89e537
DIFF: https://github.com/llvm/llvm-project/commit/495e1a42579c8f540303ab72e6449a448b89e537.diff
LOG: [mlir] added a check in the walk to prevent catching a cos in a nested region (#190064)
The walk in SincosFusion may find a cos inside a region nested within the
sin's block. Such a cos is not in the same block as the sin, so the later
call to `isBeforeInBlock` triggers an assertion. Replaced the recursive
walk with an iteration over the block's direct operations (`getOps`), so
ops in nested regions (which are not in the same block and should not be
fused anyway) are no longer considered.
---------
Co-authored-by: Yebin Chon <ychon at nvidia.com>
Added:
Modified:
mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
mlir/test/Dialect/Math/sincos-fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
index 69407df201cfa..4d8027c604cdf 100644
--- a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -27,13 +27,11 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
math::CosOp cosOp = nullptr;
- sinOp->getBlock()->walk([&](math::CosOp op) {
+ for (auto op : sinOp->getBlock()->getOps<math::CosOp>())
if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
cosOp = op;
- return WalkResult::interrupt();
+ break;
}
- return WalkResult::advance();
- });
if (!cosOp)
return failure();
diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir
index 29fb9f12475b8..cf16f9f02f63a 100644
--- a/mlir/test/Dialect/Math/sincos-fusion.mlir
+++ b/mlir/test/Dialect/Math/sincos-fusion.mlir
@@ -74,6 +74,29 @@ func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
func.return %0 : f32
}
+// CHECK-LABEL: func.func @sincos_no_fusion_nested_region(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: i1) -> (f32, f32) {
+// CHECK: %[[SIN:.*]] = math.sin %[[ARG0]] : f32
+// CHECK: %[[IF:.*]] = scf.if %[[ARG1]] -> (f32) {
+// CHECK: %[[COS:.*]] = math.cos %[[ARG0]] : f32
+// CHECK: scf.yield %[[COS]] : f32
+// CHECK: } else {
+// CHECK: scf.yield %[[SIN]] : f32
+// CHECK: }
+// CHECK: return %[[SIN]], %[[IF]] : f32, f32
+// CHECK: }
+func.func @sincos_no_fusion_nested_region(%arg0 : f32, %flag : i1) -> (f32, f32) {
+ %s = math.sin %arg0 : f32
+ %r = scf.if %flag -> f32 {
+ %c = math.cos %arg0 : f32
+ scf.yield %c : f32
+ } else {
+ scf.yield %s : f32
+ }
+ func.return %s, %r : f32, f32
+}
+
// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath(
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32
More information about the Mlir-commits mailing list