[Mlir-commits] [mlir] c2a9725 - [mlir][math] Add constant folding for sincos/cbrt (#194130)
llvmlistbot at llvm.org
Sat Apr 25 21:03:10 PDT 2026
Author: Longsheng Mou
Date: 2026-04-26T12:03:06+08:00
New Revision: c2a9725b570407d1b057134e90d11c53bbc48b63
URL: https://github.com/llvm/llvm-project/commit/c2a9725b570407d1b057134e90d11c53bbc48b63
DIFF: https://github.com/llvm/llvm-project/commit/c2a9725b570407d1b057134e90d11c53bbc48b63.diff
LOG: [mlir][math] Add constant folding for sincos/cbrt (#194130)
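Adds constant folders for `math.sincos` and `math.cbrt`. `math.sincos` now uses the `SameOperandsAndResultType` trait in place of `SameOperandsAndResultShape` plus `AllTypesMatch`.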
Added:
Modified:
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/test/Dialect/Math/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index dc254e281f10f..926bac1de2b0d 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -366,6 +366,8 @@ def Math_CbrtOp : Math_FloatUnaryOp<"cbrt"> {
Note: This op is not equivalent to powf(..., 1/3.0).
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@@ -553,9 +555,8 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
//===----------------------------------------------------------------------===//
def Math_SincosOp : Math_Op<"sincos",
- [SameOperandsAndResultShape,
- DeclareOpInterfaceMethods<ArithFastMathInterface>,
- AllTypesMatch<["operand", "sin", "cos"]>]> {
+ [SameOperandsAndResultType,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
let summary = "sine and cosine of the specified value";
let description = [{
The `sincos` operation computes both the sine and cosine of a given value
@@ -583,6 +584,8 @@ def Math_SincosOp : Math_Op<"sincos",
let extraClassDeclaration = [{
std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 183ff511d3cb8..b900cb1911759 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -185,6 +185,24 @@ OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// CbrtOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CbrtOp::fold(FoldAdaptor adaptor) { // Constant-folds cbrt of f32/f64 constants.
+  return constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+        switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+        case APFloat::Semantics::S_IEEEdouble: // f64: computed with the host C library's cbrt().
+          return APFloat(cbrt(a.convertToDouble()));
+        case APFloat::Semantics::S_IEEEsingle: // f32: computed with the host C library's cbrtf().
+          return APFloat(cbrtf(a.convertToFloat()));
+        default: // Any other float semantics (f16, bf16, ...): no host routine is used here.
+          return {}; // Empty optional; presumably makes the helper decline to fold.
+        }
+      });
+}
+
//===----------------------------------------------------------------------===//
// CeilOp folder
//===----------------------------------------------------------------------===//
@@ -284,7 +302,7 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
}
//===----------------------------------------------------------------------===//
-// SinCosOp getShapeForUnroll
+// SincosOp
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
@@ -293,6 +311,35 @@ std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
return std::nullopt;
}
+LogicalResult math::SincosOp::fold(FoldAdaptor adaptor,
+                                   SmallVectorImpl<OpFoldResult> &result) { // Folds both (sin, cos) results for f32/f64 constants.
+  auto foldSincos = [](const APFloat &a, double (*fnDouble)(double), // Dispatches to the C library routine matching the operand precision.
+                       float (*fnFloat)(float)) -> std::optional<APFloat> {
+    switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+    case APFloat::Semantics::S_IEEEdouble:
+      return APFloat(fnDouble(a.convertToDouble()));
+    case APFloat::Semantics::S_IEEEsingle:
+      return APFloat(fnFloat(a.convertToFloat()));
+    default: // Other float semantics (f16, bf16, ...) are not folded.
+      return {};
+    }
+  };
+  // Fold sin and cos independently over the same constant operand.
+  Attribute sinRes = constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(),
+      [&](const APFloat &a) { return foldSincos(a, sin, sinf); });
+  Attribute cosRes = constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(),
+      [&](const APFloat &a) { return foldSincos(a, cos, cosf); });
+  // Commit only if both folded, so `result` receives exactly (sin, cos) or nothing.
+  if (sinRes && cosRes) {
+    result.push_back(sinRes);
+    result.push_back(cosRes);
+    return success();
+  }
+  return failure();
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 228faa31781c4..3459164c5c0a7 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -647,3 +647,54 @@ func.func @fpowi_fold_failed() -> f32 {
%0 = math.fpowi %cst, %c16777217_i32 : f32, i32
return %0 : f32
}
+
+// CHECK-LABEL: @sincos_fold_f32
+// CHECK: %[[sin:.+]] = arith.constant 0.84{{[0-9]+}} : f32
+// CHECK: %[[cos:.+]] = arith.constant 0.54{{[0-9]+}} : f32
+// CHECK: return %[[sin]], %[[cos]]
+func.func @sincos_fold_f32() -> (f32, f32) {
+  %cst = arith.constant 1.000000e+00 : f32 // sin(1) ~ 0.8415, cos(1) ~ 0.5403.
+  %sin, %cos = math.sincos %cst : f32
+  return %sin, %cos : f32, f32
+}
+
+// CHECK-LABEL: @sincos_fold_f64
+// CHECK: %[[sin:.+]] = arith.constant 0.84{{[0-9]+}} : f64
+// CHECK: %[[cos:.+]] = arith.constant 0.54{{[0-9]+}} : f64
+// CHECK: return %[[sin]], %[[cos]]
+func.func @sincos_fold_f64() -> (f64, f64) {
+  %cst = arith.constant 1.000000e+00 : f64 // sin(1) ~ 0.8415, cos(1) ~ 0.5403.
+  %sin, %cos = math.sincos %cst : f64
+  return %sin, %cos : f64, f64
+}
+
+// CHECK-LABEL: @sincos_fold_vec
+// CHECK: %[[sin:.+]] = arith.constant dense<[0.000000e+00, 0.84{{[0-9]+}}, 0.000000e+00, 0.84{{[0-9]+}}]> : vector<4xf32>
+// CHECK: %[[cos:.+]] = arith.constant dense<[1.000000e+00, 0.54{{[0-9]+}}, 1.000000e+00, 0.54{{[0-9]+}}]> : vector<4xf32>
+// CHECK: return %[[sin]], %[[cos]]
+func.func @sincos_fold_vec() -> (vector<4xf32>, vector<4xf32>) {
+  %cst = arith.constant dense<[0.0, 1.0, 0.0, 1.0]> : vector<4xf32> // 0.0 folds exactly to sin=0, cos=1.
+  %sin, %cos = math.sincos %cst : vector<4xf32>
+  return %sin, %cos : vector<4xf32>, vector<4xf32>
+}
+
+// CHECK-LABEL: @cbrt_fold
+// CHECK: %[[cst:.+]] = arith.constant 2.000000e+00 : f64
+// CHECK: %[[cst0:.+]] = arith.constant -2.000000e+00 : f32
+// CHECK: return %[[cst]], %[[cst0]]
+func.func @cbrt_fold() -> (f64, f32) { // NOTE(review): no f16/bf16 test pins the folders' unfolded default case.
+  %cst = arith.constant 8.000000e+00 : f64 // cbrt(8) = 2 exactly.
+  %cst_0 = arith.constant -8.000000e+00 : f32 // Negative input: cbrt(-8) = -2.
+  %0 = math.cbrt %cst : f64
+  %1 = math.cbrt %cst_0 : f32
+  return %0, %1 : f64, f32
+}
+
+// CHECK-LABEL: @cbrt_fold_vec
+// CHECK: %[[cst:.+]] = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.000000e+00, 2.000000e+00]> : vector<4xf32>
+// CHECK: return %[[cst]]
+func.func @cbrt_fold_vec() -> vector<4xf32> {
+  %cst = arith.constant dense<[-1.0, 0.0, 1.0, 8.0]> : vector<4xf32> // Perfect cubes, so every result is exact.
+  %0 = math.cbrt %cst : vector<4xf32>
+  return %0 : vector<4xf32>
+}
More information about the Mlir-commits mailing list