[Mlir-commits] [mlir] c2a9725 - [mlir][math] Add constant folding for sincos/cbrt (#194130)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Apr 25 21:03:10 PDT 2026


Author: Longsheng Mou
Date: 2026-04-26T12:03:06+08:00
New Revision: c2a9725b570407d1b057134e90d11c53bbc48b63

URL: https://github.com/llvm/llvm-project/commit/c2a9725b570407d1b057134e90d11c53bbc48b63
DIFF: https://github.com/llvm/llvm-project/commit/c2a9725b570407d1b057134e90d11c53bbc48b63.diff

LOG: [mlir][math] Add constant folding for sincos/cbrt (#194130)

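Adds constant folders for `math.sincos` and `math.cbrt`. `math.sincos` now declares `SameOperandsAndResultType` in place of `SameOperandsAndResultShape` plus `AllTypesMatch`.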

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Dialect/Math/IR/MathOps.cpp
    mlir/test/Dialect/Math/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index dc254e281f10f..926bac1de2b0d 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -366,6 +366,8 @@ def Math_CbrtOp : Math_FloatUnaryOp<"cbrt"> {
 
     Note: This op is not equivalent to powf(..., 1/3.0).
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -553,9 +555,8 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
 //===----------------------------------------------------------------------===//
 
 def Math_SincosOp : Math_Op<"sincos",
-    [SameOperandsAndResultShape,
-     DeclareOpInterfaceMethods<ArithFastMathInterface>,
-     AllTypesMatch<["operand", "sin", "cos"]>]> {
+    [SameOperandsAndResultType,
+     DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
   let summary = "sine and cosine of the specified value";
   let description = [{
     The `sincos` operation computes both the sine and cosine of a given value
@@ -583,6 +584,8 @@ def Math_SincosOp : Math_Op<"sincos",
   let extraClassDeclaration = [{
     std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 183ff511d3cb8..b900cb1911759 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -185,6 +185,24 @@ OpFoldResult math::Atan2Op::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// CbrtOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::CbrtOp::fold(FoldAdaptor adaptor) {
+  // Cube root via libm in the operand's own precision; only f64/f32 fold.
+  auto cubeRoot = [](const APFloat &v) -> std::optional<APFloat> {
+    APFloat::Semantics sem = APFloat::SemanticsToEnum(v.getSemantics());
+    if (sem == APFloat::Semantics::S_IEEEdouble)
+      return APFloat(cbrt(v.convertToDouble()));
+    if (sem == APFloat::Semantics::S_IEEEsingle)
+      return APFloat(cbrtf(v.convertToFloat()));
+    return std::nullopt;
+  };
+  return constFoldUnaryOpConditional<FloatAttr>(adaptor.getOperands(),
+                                                cubeRoot);
+}
+
 //===----------------------------------------------------------------------===//
 // CeilOp folder
 //===----------------------------------------------------------------------===//
@@ -284,7 +302,7 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
 }
 
 //===----------------------------------------------------------------------===//
-// SinCosOp getShapeForUnroll
+// SinCosOp
 //===----------------------------------------------------------------------===//
 
 std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
@@ -293,6 +311,35 @@ std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
   return std::nullopt;
 }
 
+LogicalResult math::SincosOp::fold(FoldAdaptor adaptor,
+                                   SmallVectorImpl<OpFoldResult> &result) {
+  // Folds only when both sin and cos fold. libm is called directly rather
+  // than via function pointers, whose formation is unspecified in C++20.
+  auto evaluate = [](const APFloat &a, bool isSin) -> std::optional<APFloat> {
+    switch (APFloat::SemanticsToEnum(a.getSemantics())) {
+    case APFloat::Semantics::S_IEEEdouble:
+      return APFloat(isSin ? sin(a.convertToDouble())
+                           : cos(a.convertToDouble()));
+    case APFloat::Semantics::S_IEEEsingle:
+      return APFloat(isSin ? sinf(a.convertToFloat())
+                           : cosf(a.convertToFloat()));
+    default: // Other semantics (e.g. IEEE half) are left unfolded.
+      return {};
+    }
+  };
+  auto foldWith = [&](bool isSin) -> Attribute {
+    return constFoldUnaryOpConditional<FloatAttr>(
+        adaptor.getOperands(),
+        [&](const APFloat &a) { return evaluate(a, isSin); });
+  };
+  Attribute sinRes = foldWith(/*isSin=*/true);
+  Attribute cosRes = sinRes ? foldWith(/*isSin=*/false) : Attribute();
+  if (!sinRes || !cosRes)
+    return failure();
+  result.append({sinRes, cosRes});
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // CountLeadingZerosOp folder
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Math/canonicalize.mlir b/mlir/test/Dialect/Math/canonicalize.mlir
index 228faa31781c4..3459164c5c0a7 100644
--- a/mlir/test/Dialect/Math/canonicalize.mlir
+++ b/mlir/test/Dialect/Math/canonicalize.mlir
@@ -647,3 +647,54 @@ func.func @fpowi_fold_failed() -> f32 {
   %0 = math.fpowi %cst, %c16777217_i32 : f32, i32
   return %0 : f32
 }
+
+// CHECK-LABEL: @sincos_fold_f32
+// CHECK: %[[SIN:.+]] = arith.constant 0.84{{[0-9]+}} : f32
+// CHECK: %[[COS:.+]] = arith.constant 0.54{{[0-9]+}} : f32
+// CHECK: return %[[SIN]], %[[COS]]
+func.func @sincos_fold_f32() -> (f32, f32) {
+  %one = arith.constant 1.0 : f32
+  %s, %c = math.sincos %one : f32
+  return %s, %c : f32, f32
+}
+
+// CHECK-LABEL: @sincos_fold_f64
+// CHECK: %[[SIN:.+]] = arith.constant 0.84{{[0-9]+}} : f64
+// CHECK: %[[COS:.+]] = arith.constant 0.54{{[0-9]+}} : f64
+// CHECK: return %[[SIN]], %[[COS]]
+func.func @sincos_fold_f64() -> (f64, f64) {
+  %one = arith.constant 1.0 : f64
+  %s, %c = math.sincos %one : f64
+  return %s, %c : f64, f64
+}
+
+// CHECK-LABEL: @sincos_fold_vec
+// CHECK: %[[SIN:.+]] = arith.constant dense<[0.000000e+00, 0.84{{[0-9]+}}, 0.000000e+00, 0.84{{[0-9]+}}]> : vector<4xf32>
+// CHECK: %[[COS:.+]] = arith.constant dense<[1.000000e+00, 0.54{{[0-9]+}}, 1.000000e+00, 0.54{{[0-9]+}}]> : vector<4xf32>
+// CHECK: return %[[SIN]], %[[COS]]
+func.func @sincos_fold_vec() -> (vector<4xf32>, vector<4xf32>) {
+  %angles = arith.constant dense<[0.0, 1.0, 0.0, 1.0]> : vector<4xf32>
+  %s, %c = math.sincos %angles : vector<4xf32>
+  return %s, %c : vector<4xf32>, vector<4xf32>
+}
+
+// CHECK-LABEL: @cbrt_fold
+// CHECK: %[[TWO_F64:.+]] = arith.constant 2.000000e+00 : f64
+// CHECK: %[[NEG_TWO_F32:.+]] = arith.constant -2.000000e+00 : f32
+// CHECK: return %[[TWO_F64]], %[[NEG_TWO_F32]]
+func.func @cbrt_fold() -> (f64, f32) {
+  %eight = arith.constant 8.0 : f64
+  %neg_eight = arith.constant -8.0 : f32
+  %root64 = math.cbrt %eight : f64
+  %root32 = math.cbrt %neg_eight : f32
+  return %root64, %root32 : f64, f32
+}
+
+// CHECK-LABEL: @cbrt_fold_vec
+// CHECK: %[[ROOTS:.+]] = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.000000e+00, 2.000000e+00]> : vector<4xf32>
+// CHECK: return %[[ROOTS]]
+func.func @cbrt_fold_vec() -> vector<4xf32> {
+  %input = arith.constant dense<[-1.0, 0.0, 1.0, 8.0]> : vector<4xf32>
+  %roots = math.cbrt %input : vector<4xf32>
+  return %roots : vector<4xf32>
+}


        


More information about the Mlir-commits mailing list