[Mlir-commits] [mlir] [MLIR] Add sincos fusion pass (PR #161413)

Asher Mancinelli llvmlistbot at llvm.org
Tue Sep 30 12:46:42 PDT 2025


https://github.com/ashermancinelli updated https://github.com/llvm/llvm-project/pull/161413

>From c86c241b9338c31031586002bb9d7ac472ba6e3b Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Tue, 30 Sep 2025 08:44:24 -0700
Subject: [PATCH 1/2] [MLIR] Add sincos fusion pass

We see performance improvements from using sincos to share calculations
in hot loops that compute both sin() and cos() of the same operand.
Add a pass that finds math.sin and math.cos ops in the same block with the
same operand and the same fast-math flags, and fuses each such pair into a
single math.sincos op.

Follow-up to:
* #160561
* #160772
---
 .../mlir/Dialect/Math/Transforms/Passes.td    |  8 ++
 .../Dialect/Math/Transforms/CMakeLists.txt    |  1 +
 .../Dialect/Math/Transforms/SincosFusion.cpp  | 77 +++++++++++++++++++
 mlir/test/Dialect/Math/sincos-fusion.mlir     | 64 +++++++++++++++
 4 files changed, 150 insertions(+)
 create mode 100644 mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
 create mode 100644 mlir/test/Dialect/Math/sincos-fusion.mlir

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 4d415aeac8f58..48346abd84285 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -64,4 +64,12 @@ def MathExpandOpsPass : Pass<"math-expand-ops"> {
   ];
 }
 
+def MathSincosFusionPass : Pass<"math-sincos-fusion"> {
+  let summary = "Fuse sin and cos operations.";
+  let description = [{
+    Fuse sin and cos operations into a sincos operation.
+  }];
+  let dependentDialects = ["math::MathDialect"];
+}
+
 #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index ff62b515533c3..8899c3a1d1a42 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms
   ExpandOps.cpp
   ExtendToSupportedTypes.cpp
   PolynomialApproximation.cpp
+  SincosFusion.cpp
   UpliftToFMA.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
new file mode 100644
index 0000000000000..a373cf70b5541
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -0,0 +1,77 @@
+//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::math;
+
+namespace {
+
+/// Fuse a math.sin and math.cos in the same block that use the same operand and
+/// have identical fastmath flags into a single math.sincos.
+struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
+  using OpRewritePattern<math::SinOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::SinOp sinOp,
+                                PatternRewriter &rewriter) const override {
+    Value operand = sinOp.getOperand();
+    mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
+
+    math::CosOp cosOp = nullptr;
+    sinOp->getBlock()->walk([&](math::CosOp op) {
+      if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
+        cosOp = op;
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+
+    if (!cosOp)
+      return failure();
+
+    Type elemType = sinOp.getType();
+    auto sincos = rewriter.create<math::SincosOp>(
+        sinOp.getLoc(), TypeRange{elemType, elemType}, operand,
+        sinOp.getFastmathAttr());
+
+    rewriter.replaceOp(sinOp, sincos.getSin());
+    rewriter.replaceOp(cosOp, sincos.getCos());
+    return success();
+  }
+};
+
+} // namespace
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+namespace {
+
+struct MathSincosFusionPass final
+    : math::impl::MathSincosFusionPassBase<MathSincosFusionPass> {
+  using MathSincosFusionPassBase::MathSincosFusionPassBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<SincosFusionPattern>(&getContext());
+
+    GreedyRewriteConfig config;
+    if (failed(
+            applyPatternsGreedily(getOperation(), std::move(patterns), config)))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir
new file mode 100644
index 0000000000000..9abf576c858bd
--- /dev/null
+++ b/mlir/test/Dialect/Math/sincos-fusion.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt -math-sincos-fusion %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @sincos_fusion(
+// CHECK-SAME:      %[[ARG0:.*]]: f32,
+// CHECK-SAME:      %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
+// CHECK:           %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
+// CHECK:           %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
+// CHECK:           return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
+// CHECK:         }
+func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
+    %0 = math.sin %arg0 : f32
+    %1 = math.cos %arg0 : f32
+
+    %2 = math.cos %arg1 : f32
+    %3 = math.sin %arg1 : f32
+
+    func.return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
+// CHECK-LABEL:   func.func @sincos_fusion_no_match_fmf(
+// CHECK-SAME:      %[[ARG0:.*]]: f32) -> (f32, f32) {
+// CHECK:           %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32
+// CHECK:           %[[VAL_1:.*]] = math.cos %[[ARG0]] : f32
+// CHECK:           return %[[VAL_0]], %[[VAL_1]] : f32, f32
+// CHECK:         }
+func.func @sincos_fusion_no_match_fmf(%arg0 : f32) -> (f32, f32) {
+    %0 = math.sin %arg0 fastmath<contract> : f32
+    %1 = math.cos %arg0 : f32
+    func.return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL:   func.func @sincos_no_fusion_different_block(
+// CHECK-SAME:      %[[ARG0:.*]]: f32,
+// CHECK-SAME:      %[[ARG1:.*]]: i1) -> f32 {
+// CHECK:           %[[VAL_0:.*]] = scf.if %[[ARG1]] -> (f32) {
+// CHECK:             %[[VAL_1:.*]] = math.sin %[[ARG0]] : f32
+// CHECK:             scf.yield %[[VAL_1]] : f32
+// CHECK:           } else {
+// CHECK:             %[[VAL_2:.*]] = math.cos %[[ARG0]] : f32
+// CHECK:             scf.yield %[[VAL_2]] : f32
+// CHECK:           }
+// CHECK:           return %[[VAL_0]] : f32
+// CHECK:         }
+func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
+  %0 = scf.if %flag -> f32 {
+    %s = math.sin %arg0 : f32
+    scf.yield %s : f32
+  } else {
+    %c = math.cos %arg0 : f32
+    scf.yield %c : f32
+  }
+  func.return %0 : f32
+}
+
+// CHECK-LABEL:   func.func @sincos_fusion_preserve_fastmath(
+// CHECK-SAME:      %[[ARG0:.*]]: f32) -> (f32, f32) {
+// CHECK:           %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32
+// CHECK:           return %[[VAL_0]], %[[VAL_1]] : f32, f32
+// CHECK:         }
+func.func @sincos_fusion_preserve_fastmath(%arg0 : f32) -> (f32, f32) {
+    %0 = math.sin %arg0 fastmath<contract> : f32
+    %1 = math.cos %arg0 fastmath<contract> : f32
+    func.return %0, %1 : f32, f32
+}

>From 3e072be6952a2730050007c7fd42b726ab86f38f Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Tue, 30 Sep 2025 12:46:13 -0700
Subject: [PATCH 2/2] Insert the fused sincos before the earlier of the sin/cos ops

---
 .../Dialect/Math/Transforms/SincosFusion.cpp  | 12 ++++++----
 mlir/test/Dialect/Math/sincos-fusion.mlir     | 22 +++++++++++++++++++
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
index a373cf70b5541..717c4f0867dab 100644
--- a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -9,7 +9,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
+// #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
@@ -39,10 +39,14 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
     if (!cosOp)
       return failure();
 
+    Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation()
+                                                       : cosOp.getOperation();
+    rewriter.setInsertionPoint(firstOp);
+
     Type elemType = sinOp.getType();
-    auto sincos = rewriter.create<math::SincosOp>(
-        sinOp.getLoc(), TypeRange{elemType, elemType}, operand,
-        sinOp.getFastmathAttr());
+    auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(),
+                                         TypeRange{elemType, elemType}, operand,
+                                         sinOp.getFastmathAttr());
 
     rewriter.replaceOp(sinOp, sincos.getSin());
     rewriter.replaceOp(cosOp, sincos.getCos());
diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir
index 9abf576c858bd..29fb9f12475b8 100644
--- a/mlir/test/Dialect/Math/sincos-fusion.mlir
+++ b/mlir/test/Dialect/Math/sincos-fusion.mlir
@@ -17,6 +17,28 @@ func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
     func.return %0, %1, %2, %3 : f32, f32, f32, f32
 }
 
+func.func private @sink(%arg0 : f32)
+
+// CHECK:         func.func private @sink(f32)
+// CHECK-LABEL:   func.func @sincos_ensure_ssa_dominance(
+// CHECK-SAME:      %[[ARG0:.*]]: f32,
+// CHECK-SAME:      %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
+// CHECK:           %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
+// CHECK:           call @sink(%[[VAL_0]]) : (f32) -> ()
+// CHECK:           %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
+// CHECK:           call @sink(%[[VAL_3]]) : (f32) -> ()
+// CHECK:           return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
+// CHECK:         }
+func.func @sincos_ensure_ssa_dominance(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
+    %0 = math.sin %arg0 : f32
+    func.call @sink(%0) : (f32) -> ()
+    %1 = math.cos %arg0 : f32
+    %2 = math.cos %arg1 : f32
+    func.call @sink(%2) : (f32) -> ()
+    %3 = math.sin %arg1 : f32
+    func.return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
 // CHECK-LABEL:   func.func @sincos_fusion_no_match_fmf(
 // CHECK-SAME:      %[[ARG0:.*]]: f32) -> (f32, f32) {
 // CHECK:           %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32



More information about the Mlir-commits mailing list