[Mlir-commits] [mlir] [MLIR] Add sincos op to math dialect (PR #160772)

Asher Mancinelli llvmlistbot at llvm.org
Tue Sep 30 07:23:49 PDT 2025


https://github.com/ashermancinelli updated https://github.com/llvm/llvm-project/pull/160772

>From 3fcbaee738982ab6df9df883dbf0fdc63302862e Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 08:48:11 -0700
Subject: [PATCH 1/7] [MLIR] Add sincos operation to math dialect

Now that `sincos` is a supported intrinsic in the LLVM dialect (https://github.com/llvm/llvm-project/pull/160561), we can add the corresponding operation to the math dialect.

We have several benchmarks that use sine and cosine in hot loops, and computing both values together rather than separately saves work and can improve performance. We would like a way to represent `sincos` in the math dialect.
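
For context, this is the kind of rewrite the new op is meant to enable (a sketch; `%x` and the result names are illustrative and not taken from this patch):

```mlir
// Before: two separate transcendental evaluations of the same operand.
%s0 = math.sin %x : f32
%c0 = math.cos %x : f32

// After: one fused evaluation producing both results.
%s1, %c1 = math.sincos %x : f32 -> f32, f32
```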

Two parts I'm unsure about:
* What do we think of the assembly format, `math.sincos %floatlike : f32 -> f32, f32`? With a custom assembly format we could omit the `->` and everything after it, but I couldn't get ODS to do that. Open to suggestions (see the sketch after this list).
* I implement `getShapeForUnroll()` here, but where is the best place to test the unroller interfaces? I'll keep poking around after sending this out for review.
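
For reference, the two candidate spellings from the first bullet (a sketch; the shorter form assumes the result types can be inferred from the operand type, which the ODS declaration in this patch does not yet do):

```mlir
// Format produced by the current ODS declaration: result types are explicit.
%sin0, %cos0 = math.sincos %input : f32 -> f32, f32

// Shorter form under discussion: elide the `->` and the result types.
%sin1, %cos1 = math.sincos %input : f32
```
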
---
 mlir/include/mlir/Dialect/Math/IR/MathOps.td  | 38 ++++++++
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 93 ++++++++++++++++++-
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 33 +++++++
 mlir/lib/Dialect/Math/IR/MathOps.cpp          | 22 +++++
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     | 39 ++++++++
 .../Conversion/MathToLLVM/math-to-llvm.mlir   | 10 ++
 mlir/test/Dialect/Math/ops.mlir               | 12 +++
 7 files changed, 246 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index cfd8c4b8f11f7..a7e79f2efd4c5 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -510,6 +510,44 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SinCosOp
+//===----------------------------------------------------------------------===//
+
+def Math_SincosOp : Math_Op<"sincos",
+    [SameOperandsAndResultShape,
+     DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+  let summary = "sine and cosine of the specified value";
+  let description = [{
+    The `sincos` operation computes both the sine and cosine of a given value
+    simultaneously. It takes one operand of floating point type (i.e., scalar,
+    tensor or vector) and returns two results of the same type. This operation
+    can be more efficient than computing sine and cosine separately when both
+    values are needed.
+
+    Example:
+
+    ```mlir
+    // Scalar sine and cosine values.
+    %sin, %cos = math.sincos %input : f64 -> f64, f64
+    ```
+  }];
+
+  let arguments = (ins FloatLike:$operand,
+      DefaultValuedAttr<Arith_FastMathAttr,
+                        "::mlir::arith::FastMathFlags::none">:$fastmath);
+  let results = (outs FloatLike:$sin, FloatLike:$cos);
+
+  let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+                          attr-dict `:` type($operand) `->` type($sin) `,` type($cos) }];
+
+  let extraClassDeclaration = [{
+    std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // CountLeadingZerosOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index a95263bb55f69..16d765f2b2561 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
                       LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
                       LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
                       LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
-                      LLVM::SqrtOp>();
+                      LLVM::SincosOp, LLVM::SqrtOp>();
 
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,94 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
   });
 }
 
+// Custom lowering for math.sincos to __nv_sincosf/__nv_sincos libdevice calls
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = adaptor.getOperand();
+    Type inputType = input.getType();
+    auto convertedInput = maybeExt(input, rewriter);
+    auto computeType = convertedInput.getType();
+
+    StringRef sincosFunc;
+    if (isa<Float32Type>(computeType)) {
+      const arith::FastMathFlags flag = op.getFastmath();
+      const bool useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
+      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+    } else if (isa<Float64Type>(computeType)) {
+      sincosFunc = "__nv_sincos";
+    } else {
+      return rewriter.notifyMatchFailure(op, "unsupported operand type for sincos");
+    }
+
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+    
+    Value sinPtr, cosPtr;
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      auto *scope = op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+      assert(scope && "Expected op to be inside automatic allocation scope");
+      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
+      auto one = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+      sinPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+      cosPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+    }
+
+    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, op);
+
+    auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
+    auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+
+    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
+                            maybeTrunc(cosResult, inputType, rewriter)});
+    return success();
+  }
+
+private:
+  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+    if (isa<Float16Type, BFloat16Type>(operand.getType())) {
+      return rewriter.create<LLVM::FPExtOp>(operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+    }
+    return operand;
+  }
+
+  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+    if (operand.getType() != type)
+      return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+    return operand;
+  }
+
+  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+                        StringRef funcName, Value input, Value sinPtr, Value cosPtr,
+                        Operation *op) const {
+    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+    auto ptrType = sinPtr.getType();
+    
+    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
+    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
+    
+    auto funcAttr = StringAttr::get(op->getContext(), funcName);
+    auto funcOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+    
+    if (!funcOp) {
+      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+      assert(parentFunc && "expected there to be a parent function");
+      OpBuilder b(parentFunc);
+      
+      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
+    }
+    
+    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
+    rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+  }
+};
+
 template <typename OpTy>
 static void populateOpPatterns(const LLVMTypeConverter &converter,
                                RewritePatternSet &patterns,
@@ -589,6 +677,9 @@ void mlir::populateLibDeviceConversionPatterns(
                                   "__nv_tan", "__nv_fast_tanf");
   populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
                                    "__nv_tanh");
+
+  // Custom pattern for sincos since it returns two values
+  patterns.add<SincosOpLowering>(converter, benefit);
 }
 
 void mlir::populateGpuToNVVMConversionPatterns(
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 853f45498ac52..73a003ef4e6c1 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering =
                           LLVM::CountTrailingZerosOp>;
 using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
 
+// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *this->getTypeConverter();
+    auto loc = op.getLoc();
+    auto operandType = adaptor.getOperand().getType();
+    auto llvmOperandType = typeConverter.convertType(operandType);
+    auto sinType = typeConverter.convertType(op.getSin().getType());
+    auto cosType = typeConverter.convertType(op.getCos().getType());
+    if (!llvmOperandType || !sinType || !cosType)
+      return failure();
+
+    ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
+
+    auto structType = LLVM::LLVMStructType::getLiteral(
+        rewriter.getContext(), {llvmOperandType, llvmOperandType});
+
+    auto sincosOp = rewriter.create<LLVM::SincosOp>(
+        loc, structType, adaptor.getOperand(), attrs.getAttrs());
+
+    auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
+    auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
+
+    rewriter.replaceOp(op, {sinValue, cosValue});
+    return success();
+  }
+};
+
 // A `expm1` is converted into `exp - 1`.
 struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
   using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns(
     RoundEvenOpLowering,
     RoundOpLowering,
     RsqrtOpLowering,
+    SincosOpLowering,
     SinOpLowering,
     SinhOpLowering,
     ASinOpLowering,
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index a21631cbf8510..f0bf62770d4cc 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -284,6 +284,28 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// SinCosOp verifier and getShapeForUnroll
+//===----------------------------------------------------------------------===//
+
+LogicalResult math::SincosOp::verify() {
+  Type operandType = getOperand().getType();
+  Type sinType = getSin().getType();
+  Type cosType = getCos().getType();
+
+  if (operandType != sinType || operandType != cosType) {
+    return emitOpError("result types must match operand type");
+  }
+
+  return success();
+}
+
+std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
+  if (auto vt = mlir::dyn_cast_or_null<VectorType>(getOperand().getType()))
+    return llvm::to_vector<4>(vt.getShape());
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // CountLeadingZerosOp folder
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ef06af3ad3163..cdefc4d6098c7 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1109,3 +1109,42 @@ gpu.module @test_module_55 {
     func.return %result32, %result64 : f32, f64
   }
 }
+
+gpu.module @test_module_56 {
+  // CHECK: gpu.module @test_module_56
+
+  // CHECK-DAG: llvm.func @__nv_sincosf(f32, !llvm.ptr, !llvm.ptr)
+  // CHECK-DAG: llvm.func @__nv_sincos(f64, !llvm.ptr, !llvm.ptr)
+
+  // CHECK-LABEL: func @gpu_sincos
+  // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+  func.func @gpu_sincos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+    // CHECK-COUNT-6: llvm.alloca
+    // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+    // CHECK: llvm.call @__nv_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+    // CHECK-COUNT-2: llvm.fptrunc
+    // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+    // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+    %sin16, %cos16 = math.sincos %arg_f16 : f16 -> f16, f16
+    %sin32, %cos32 = math.sincos %arg_f32 : f32 -> f32, f32
+    %sin64, %cos64 = math.sincos %arg_f64 : f64 -> f64, f64
+    func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+  }
+
+  // CHECK: llvm.func @__nv_fast_sincosf(f32, !llvm.ptr, !llvm.ptr)
+
+  // CHECK-LABEL: func @gpu_sincos_fastmath
+  // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+  func.func @gpu_sincos_fastmath(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+    // CHECK-COUNT-6: llvm.alloca
+    // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+    // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+    // CHECK-COUNT-2: llvm.fptrunc
+    // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+    // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+    %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16 -> f16, f16
+    %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32 -> f32, f32
+    %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64 -> f64, f64
+    func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+  }
+}
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index f4541220fe4d2..9030ba9c93e55 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -230,6 +230,16 @@ func.func @trigonometrics(%arg0: f32) {
 
 // -----
 
+// CHECK-LABEL: func @sincos
+// CHECK-SAME: [[ARG0:%.+]]: f32
+func.func @sincos(%arg0: f32) {
+  // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)>
+  %0:2 = math.sincos %arg0 : f32 -> f32, f32
+  func.return
+}
+
+// -----
+
 // CHECK-LABEL: func @inverse_trigonometrics
 // CHECK-SAME: [[ARG0:%.+]]: f32
 func.func @inverse_trigonometrics(%arg0: f32) {
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index cb10fc4397ffc..5d3a8a6d87bed 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -62,6 +62,18 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @sincos(
+// CHECK-SAME:            %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+  // CHECK: %{{.*}} = math.sincos %[[F]] : f32
+  %0:2 = math.sincos %f : f32 -> f32, f32
+  // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32>
+  %1:2 = math.sincos %v : vector<4xf32> -> vector<4xf32>, vector<4xf32>
+  // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32>
+  %2:2 = math.sincos %t : tensor<4x4x?xf32> -> tensor<4x4x?xf32>, tensor<4x4x?xf32>
+  return
+}
+
 // CHECK-LABEL: func @erf(
 // CHECK-SAME:            %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
 func.func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {

>From 6e2b34c11fd688312bcab520d4e74b5e12e10ae9 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 13:24:24 -0700
Subject: [PATCH 2/7] Formatting

---
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 40 +++++++++++--------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 16d765f2b2561..2c0a3305518e1 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -482,29 +482,35 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
     StringRef sincosFunc;
     if (isa<Float32Type>(computeType)) {
       const arith::FastMathFlags flag = op.getFastmath();
-      const bool useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
+      const bool useApprox =
+          ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
       sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
     } else if (isa<Float64Type>(computeType)) {
       sincosFunc = "__nv_sincos";
     } else {
-      return rewriter.notifyMatchFailure(op, "unsupported operand type for sincos");
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported operand type for sincos");
     }
 
     auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
-    
+
     Value sinPtr, cosPtr;
     {
       OpBuilder::InsertionGuard guard(rewriter);
-      auto *scope = op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+      auto *scope =
+          op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
       assert(scope && "Expected op to be inside automatic allocation scope");
       rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
       auto one = rewriter.create<LLVM::ConstantOp>(
           loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
-      sinPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
-      cosPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+      sinPtr =
+          rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+      cosPtr =
+          rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
     }
 
-    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, op);
+    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
+                     op);
 
     auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
     auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
@@ -517,7 +523,8 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
 private:
   Value maybeExt(Value operand, PatternRewriter &rewriter) const {
     if (isa<Float16Type, BFloat16Type>(operand.getType())) {
-      return rewriter.create<LLVM::FPExtOp>(operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+      return rewriter.create<LLVM::FPExtOp>(
+          operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
     }
     return operand;
   }
@@ -529,26 +536,27 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
   }
 
   void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
-                        StringRef funcName, Value input, Value sinPtr, Value cosPtr,
-                        Operation *op) const {
+                        StringRef funcName, Value input, Value sinPtr,
+                        Value cosPtr, Operation *op) const {
     auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
     auto ptrType = sinPtr.getType();
-    
+
     SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
     auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
-    
+
     auto funcAttr = StringAttr::get(op->getContext(), funcName);
-    auto funcOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
-    
+    auto funcOp =
+        SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
     if (!funcOp) {
       auto parentFunc = op->getParentOfType<FunctionOpInterface>();
       assert(parentFunc && "expected there to be a parent function");
       OpBuilder b(parentFunc);
-      
+
       auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
       funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
     }
-    
+
     SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
     rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
   }

>From dfea012c6c10386620c341eba82a690af926e969 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 13:26:32 -0700
Subject: [PATCH 3/7] Remove needless comment

---
 mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2c0a3305518e1..2b46a01c3b0e5 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -466,7 +466,6 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
   });
 }
 
-// Custom lowering for math.sincos to __nv_sincosf/__nv_sincos libdevice calls
 struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
   using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
 

>From 1a24ecca6e955b70c09f11bccc4ee1bc6b41a1fc Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 13:43:30 -0700
Subject: [PATCH 4/7] Remove braces on single-line if

---
 mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 3 +--
 mlir/lib/Dialect/Math/IR/MathOps.cpp                   | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2b46a01c3b0e5..f8f2104d2bd6a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -521,10 +521,9 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
 
 private:
   Value maybeExt(Value operand, PatternRewriter &rewriter) const {
-    if (isa<Float16Type, BFloat16Type>(operand.getType())) {
+    if (isa<Float16Type, BFloat16Type>(operand.getType()))
       return rewriter.create<LLVM::FPExtOp>(
           operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
-    }
     return operand;
   }
 
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index f0bf62770d4cc..0de5636c27c3f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -293,9 +293,8 @@ LogicalResult math::SincosOp::verify() {
   Type sinType = getSin().getType();
   Type cosType = getCos().getType();
 
-  if (operandType != sinType || operandType != cosType) {
+  if (operandType != sinType || operandType != cosType)
     return emitOpError("result types must match operand type");
-  }
 
   return success();
 }

>From 145610739e99193970accde6f4e9596eb7fe4f3b Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 14:18:40 -0700
Subject: [PATCH 5/7] Refine assembly format

---
 mlir/include/mlir/Dialect/Math/IR/MathOps.td      |  7 ++++---
 mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir   | 12 ++++++------
 mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir |  2 +-
 mlir/test/Dialect/Math/ops.mlir                   |  6 +++---
 4 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index a7e79f2efd4c5..b4212056694e9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -516,7 +516,8 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
 
 def Math_SincosOp : Math_Op<"sincos",
     [SameOperandsAndResultShape,
-     DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+     DeclareOpInterfaceMethods<ArithFastMathInterface>,
+     AllTypesMatch<["operand", "sin", "cos"]>]> {
   let summary = "sine and cosine of the specified value";
   let description = [{
     The `sincos` operation computes both the sine and cosine of a given value
@@ -529,7 +530,7 @@ def Math_SincosOp : Math_Op<"sincos",
 
     ```mlir
     // Scalar sine and cosine values.
-    %sin, %cos = math.sincos %input : f64 -> f64, f64
+    %sin, %cos = math.sincos %input : f64
     ```
   }];
 
@@ -539,7 +540,7 @@ def Math_SincosOp : Math_Op<"sincos",
   let results = (outs FloatLike:$sin, FloatLike:$cos);
 
   let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
-                          attr-dict `:` type($operand) `->` type($sin) `,` type($cos) }];
+                          attr-dict `:` type($operand) }];
 
   let extraClassDeclaration = [{
     std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index cdefc4d6098c7..a4b5dde8a2187 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1125,9 +1125,9 @@ gpu.module @test_module_56 {
     // CHECK-COUNT-2: llvm.fptrunc
     // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
     // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
-    %sin16, %cos16 = math.sincos %arg_f16 : f16 -> f16, f16
-    %sin32, %cos32 = math.sincos %arg_f32 : f32 -> f32, f32
-    %sin64, %cos64 = math.sincos %arg_f64 : f64 -> f64, f64
+    %sin16, %cos16 = math.sincos %arg_f16 : f16
+    %sin32, %cos32 = math.sincos %arg_f32 : f32
+    %sin64, %cos64 = math.sincos %arg_f64 : f64
     func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
   }
 
@@ -1142,9 +1142,9 @@ gpu.module @test_module_56 {
     // CHECK-COUNT-2: llvm.fptrunc
     // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
     // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
-    %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16 -> f16, f16
-    %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32 -> f32, f32
-    %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64 -> f64, f64
+    %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16
+    %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32
+    %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64
     func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
   }
 }
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 9030ba9c93e55..f7d27120d4207 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -234,7 +234,7 @@ func.func @trigonometrics(%arg0: f32) {
 // CHECK-SAME: [[ARG0:%.+]]: f32
 func.func @sincos(%arg0: f32) {
   // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)>
-  %0:2 = math.sincos %arg0 : f32 -> f32, f32
+  %0:2 = math.sincos %arg0 : f32
   func.return
 }
 
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 5d3a8a6d87bed..f085d1c62ea86 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -66,11 +66,11 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
 // CHECK-SAME:            %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
 func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
   // CHECK: %{{.*}} = math.sincos %[[F]] : f32
-  %0:2 = math.sincos %f : f32 -> f32, f32
+  %0:2 = math.sincos %f : f32
   // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32>
-  %1:2 = math.sincos %v : vector<4xf32> -> vector<4xf32>, vector<4xf32>
+  %1:2 = math.sincos %v : vector<4xf32>
   // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32>
-  %2:2 = math.sincos %t : tensor<4x4x?xf32> -> tensor<4x4x?xf32>, tensor<4x4x?xf32>
+  %2:2 = math.sincos %t : tensor<4x4x?xf32>
   return
 }
 

>From faa84a92b9baa9a1ed4a184abc84b4db44b329c4 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 17:22:51 -0700
Subject: [PATCH 6/7] Remove custom verifier; clean up FMF handling

---
 mlir/include/mlir/Dialect/Math/IR/MathOps.td        |  2 --
 .../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp   |  2 +-
 mlir/lib/Dialect/Math/IR/MathOps.cpp                | 13 +------------
 3 files changed, 2 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index b4212056694e9..af65af6fedec6 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -545,8 +545,6 @@ def Math_SincosOp : Math_Op<"sincos",
   let extraClassDeclaration = [{
     std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
   }];
-
-  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index f8f2104d2bd6a..852c50c965f11 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -482,7 +482,7 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
     if (isa<Float32Type>(computeType)) {
       const arith::FastMathFlags flag = op.getFastmath();
       const bool useApprox =
-          ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
+          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
       sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
     } else if (isa<Float64Type>(computeType)) {
       sincosFunc = "__nv_sincos";
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 0de5636c27c3f..ca2792dd177e5 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -285,20 +285,9 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
 }
 
 //===----------------------------------------------------------------------===//
-// SinCosOp verifier and getShapeForUnroll
+// SinCosOp getShapeForUnroll
 //===----------------------------------------------------------------------===//
 
-LogicalResult math::SincosOp::verify() {
-  Type operandType = getOperand().getType();
-  Type sinType = getSin().getType();
-  Type cosType = getCos().getType();
-
-  if (operandType != sinType || operandType != cosType)
-    return emitOpError("result types must match operand type");
-
-  return success();
-}
-
 std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
   if (auto vt = mlir::dyn_cast_or_null<VectorType>(getOperand().getType()))
     return llvm::to_vector<4>(vt.getShape());

>From 90ef640e7dd16880d700220c2ce41666325281b8 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Tue, 30 Sep 2025 07:23:09 -0700
Subject: [PATCH 7/7] Spell out types; use dyn_cast on non-nullable

---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 12 ++++++------
 mlir/lib/Dialect/Math/IR/MathOps.cpp          |  2 +-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 73a003ef4e6c1..229e40e2061cb 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -128,12 +128,12 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
   LogicalResult
   matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    const auto &typeConverter = *this->getTypeConverter();
-    auto loc = op.getLoc();
-    auto operandType = adaptor.getOperand().getType();
-    auto llvmOperandType = typeConverter.convertType(operandType);
-    auto sinType = typeConverter.convertType(op.getSin().getType());
-    auto cosType = typeConverter.convertType(op.getCos().getType());
+    const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
+    mlir::Location loc = op.getLoc();
+    mlir::Type operandType = adaptor.getOperand().getType();
+    mlir::Type llvmOperandType = typeConverter.convertType(operandType);
+    mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
+    mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
     if (!llvmOperandType || !sinType || !cosType)
       return failure();
 
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index ca2792dd177e5..bbeef0f6ee9e5 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -289,7 +289,7 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
-  if (auto vt = mlir::dyn_cast_or_null<VectorType>(getOperand().getType()))
+  if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
     return llvm::to_vector<4>(vt.getShape());
   return std::nullopt;
 }


