[Mlir-commits] [mlir] [MLIR] Add sincos fusion pass (PR #161413)
Asher Mancinelli
llvmlistbot at llvm.org
Tue Sep 30 12:46:42 PDT 2025
https://github.com/ashermancinelli updated https://github.com/llvm/llvm-project/pull/161413
>From c86c241b9338c31031586002bb9d7ac472ba6e3b Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Tue, 30 Sep 2025 08:44:24 -0700
Subject: [PATCH 1/2] [MLIR] Add sincos fusion pass
We see performance improvements from using sincos to reuse calculations
in hot loops that compute sin() and cos() on the same operand.
Add a pass that finds math.sin and math.cos ops in the same block with the
same operand and fast-math flags, and fuses them into one math.sincos op.
Follow-up to:
* #160561
* #160772
---
.../mlir/Dialect/Math/Transforms/Passes.td | 8 ++
.../Dialect/Math/Transforms/CMakeLists.txt | 1 +
.../Dialect/Math/Transforms/SincosFusion.cpp | 77 +++++++++++++++++++
mlir/test/Dialect/Math/sincos-fusion.mlir | 64 +++++++++++++++
4 files changed, 150 insertions(+)
create mode 100644 mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
create mode 100644 mlir/test/Dialect/Math/sincos-fusion.mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
index 4d415aeac8f58..48346abd84285 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td
@@ -64,4 +64,12 @@ def MathExpandOpsPass : Pass<"math-expand-ops"> {
];
}
+def MathSincosFusionPass : Pass<"math-sincos-fusion"> {
+  let summary = "Fuse sin and cos operations into sincos.";
+  let description = [{
+    Fuse math.sin and math.cos that share a block, operand and fastmath flags.
+  }];
+  let dependentDialects = ["math::MathDialect"];
+}
+
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
index ff62b515533c3..8899c3a1d1a42 100644
--- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms
ExpandOps.cpp
ExtendToSupportedTypes.cpp
PolynomialApproximation.cpp
+ SincosFusion.cpp
UpliftToFMA.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
new file mode 100644
index 0000000000000..a373cf70b5541
--- /dev/null
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -0,0 +1,77 @@
+//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::math;
+
+namespace {
+
+/// Fuse a math.sin and a math.cos that live directly in the same block, take
+/// the same operand and have identical fastmath flags into one math.sincos.
+struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
+  using OpRewritePattern<math::SinOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::SinOp sinOp,
+                                PatternRewriter &rewriter) const override {
+    Value operand = sinOp.getOperand();
+    mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath();
+    // Search only the ops directly in sinOp's block. Block::walk would also
+    // descend into nested regions and could pick a cos from another block.
+    math::CosOp cosOp = nullptr;
+    for (math::CosOp op : sinOp->getBlock()->getOps<math::CosOp>()) {
+      if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) {
+        cosOp = op;
+        break;
+      }
+    }
+
+    if (!cosOp)
+      return failure();
+
+    Type elemType = sinOp.getType();
+    auto sincos = rewriter.create<math::SincosOp>(
+        sinOp.getLoc(), TypeRange{elemType, elemType}, operand,
+        sinOp.getFastmathAttr());
+
+    rewriter.replaceOp(sinOp, sincos.getSin());
+    rewriter.replaceOp(cosOp, sincos.getCos());
+    return success();
+  }
+};
+
+} // namespace
+
+namespace mlir::math {
+#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS
+#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
+} // namespace mlir::math
+
+namespace {
+
+/// Pass driver: greedily applies SincosFusionPattern to the root operation
+/// and signals pass failure when the rewrite driver reports failure.
+struct MathSincosFusionPass final
+    : math::impl::MathSincosFusionPassBase<MathSincosFusionPass> {
+  using MathSincosFusionPassBase::MathSincosFusionPassBase;
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet fusions(ctx);
+    fusions.add<SincosFusionPattern>(ctx);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(fusions),
+                                     GreedyRewriteConfig())))
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir
new file mode 100644
index 0000000000000..9abf576c858bd
--- /dev/null
+++ b/mlir/test/Dialect/Math/sincos-fusion.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt -math-sincos-fusion %s | FileCheck %s
+
+// CHECK-LABEL: func.func @sincos_fusion(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
+// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
+// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
+// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
+// CHECK: }
+func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
+ %0 = math.sin %arg0 : f32
+ %1 = math.cos %arg0 : f32
+ // Same pair on %arg1 with cos first: %2 takes the cos result, %3 the sin.
+ %2 = math.cos %arg1 : f32
+ %3 = math.sin %arg1 : f32
+
+ func.return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
+// CHECK-LABEL: func.func @sincos_fusion_no_match_fmf(
+// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
+// CHECK: %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32
+// CHECK: %[[VAL_1:.*]] = math.cos %[[ARG0]] : f32
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
+// CHECK: }
+func.func @sincos_fusion_no_match_fmf(%arg0 : f32) -> (f32, f32) {
+ %0 = math.sin %arg0 fastmath<contract> : f32  // flags differ from the cos below: no fusion
+ %1 = math.cos %arg0 : f32
+ func.return %0, %1 : f32, f32
+}
+
+// CHECK-LABEL: func.func @sincos_no_fusion_different_block(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: i1) -> f32 {
+// CHECK: %[[VAL_0:.*]] = scf.if %[[ARG1]] -> (f32) {
+// CHECK: %[[VAL_1:.*]] = math.sin %[[ARG0]] : f32
+// CHECK: scf.yield %[[VAL_1]] : f32
+// CHECK: } else {
+// CHECK: %[[VAL_2:.*]] = math.cos %[[ARG0]] : f32
+// CHECK: scf.yield %[[VAL_2]] : f32
+// CHECK: }
+// CHECK: return %[[VAL_0]] : f32
+// CHECK: }
+func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 {
+ %0 = scf.if %flag -> f32 {  // sin and cos sit in sibling blocks: no fusion
+ %s = math.sin %arg0 : f32
+ scf.yield %s : f32
+ } else {
+ %c = math.cos %arg0 : f32
+ scf.yield %c : f32
+ }
+ func.return %0 : f32
+}
+
+// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath(
+// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
+// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath<contract> : f32
+// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32
+// CHECK: }
+func.func @sincos_fusion_preserve_fastmath(%arg0 : f32) -> (f32, f32) {
+ %0 = math.sin %arg0 fastmath<contract> : f32
+ %1 = math.cos %arg0 fastmath<contract> : f32  // same flags as the sin: the fused op keeps them
+ func.return %0, %1 : f32, f32
+}
>From 3e072be6952a2730050007c7fd42b726ab86f38f Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Tue, 30 Sep 2025 12:46:13 -0700
Subject: [PATCH 2/2] Insertion point should be the dominant op
---
.../Dialect/Math/Transforms/SincosFusion.cpp | 12 ++++++----
mlir/test/Dialect/Math/sincos-fusion.mlir | 22 +++++++++++++++++++
2 files changed, 30 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
index a373cf70b5541..717c4f0867dab 100644
--- a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp
@@ -9,7 +9,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
+// #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -39,10 +39,14 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
if (!cosOp)
return failure();
+ Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation()
+ : cosOp.getOperation();
+ rewriter.setInsertionPoint(firstOp);
+
Type elemType = sinOp.getType();
- auto sincos = rewriter.create<math::SincosOp>(
- sinOp.getLoc(), TypeRange{elemType, elemType}, operand,
- sinOp.getFastmathAttr());
+ auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(),
+ TypeRange{elemType, elemType}, operand,
+ sinOp.getFastmathAttr());
rewriter.replaceOp(sinOp, sincos.getSin());
rewriter.replaceOp(cosOp, sincos.getCos());
diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir
index 9abf576c858bd..29fb9f12475b8 100644
--- a/mlir/test/Dialect/Math/sincos-fusion.mlir
+++ b/mlir/test/Dialect/Math/sincos-fusion.mlir
@@ -17,6 +17,28 @@ func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
func.return %0, %1, %2, %3 : f32, f32, f32, f32
}
+func.func private @sink(%arg0 : f32)
+
+// CHECK: func.func private @sink(f32)
+// CHECK-LABEL: func.func @sincos_ensure_ssa_dominance(
+// CHECK-SAME: %[[ARG0:.*]]: f32,
+// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
+// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
+// CHECK: call @sink(%[[VAL_0]]) : (f32) -> ()
+// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
+// CHECK: call @sink(%[[VAL_3]]) : (f32) -> ()
+// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
+// CHECK: }
+func.func @sincos_ensure_ssa_dominance(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
+ %0 = math.sin %arg0 : f32
+ func.call @sink(%0) : (f32) -> ()  // consumes the sin before the cos is defined
+ %1 = math.cos %arg0 : f32
+ %2 = math.cos %arg1 : f32
+ func.call @sink(%2) : (f32) -> ()  // consumes the cos before the sin is defined
+ %3 = math.sin %arg1 : f32
+ func.return %0, %1, %2, %3 : f32, f32, f32, f32
+}
+
// CHECK-LABEL: func.func @sincos_fusion_no_match_fmf(
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
// CHECK: %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32
More information about the Mlir-commits
mailing list