[Mlir-commits] [mlir] [MLIR] Add sincos op to math dialect (PR #160772)
Asher Mancinelli
llvmlistbot at llvm.org
Tue Sep 30 07:23:49 PDT 2025
https://github.com/ashermancinelli updated https://github.com/llvm/llvm-project/pull/160772
>From 3fcbaee738982ab6df9df883dbf0fdc63302862e Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 08:48:11 -0700
Subject: [PATCH 1/7] [MLIR] Add sincos operation to math dialect
Now that `sincos` is a supported intrinsic in the LLVM dialect (https://github.com/llvm/llvm-project/pull/160561), we are able to add the corresponding operation to the math dialect.
We have several benchmarks that use sine and cosine in hot loops, and computing sine and cosine together saves some calculations and can benefit performance. We would like to have a way to represent sincos in the math dialect.
Two parts I'm unsure about:
* What do we think of the assembly format, `math.sincos %floatlike : f32 -> f32, f32`? With a custom assembly format we could omit the `->` and everything after it, but I couldn't get ODS to do that. Open to suggestions.
* I implement `getShapeForUnroll()` here, but where is the best place to test the unroller interfaces? I'll keep poking around after sending this out for review.
---
mlir/include/mlir/Dialect/Math/IR/MathOps.td | 38 ++++++++
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 93 ++++++++++++++++++-
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 33 +++++++
mlir/lib/Dialect/Math/IR/MathOps.cpp | 22 +++++
.../Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 39 ++++++++
.../Conversion/MathToLLVM/math-to-llvm.mlir | 10 ++
mlir/test/Dialect/Math/ops.mlir | 12 +++
7 files changed, 246 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index cfd8c4b8f11f7..a7e79f2efd4c5 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -510,6 +510,44 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// SinCosOp
+//===----------------------------------------------------------------------===//
+
+def Math_SincosOp : Math_Op<"sincos",
+ [SameOperandsAndResultShape,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+ let summary = "sine and cosine of the specified value";
+ let description = [{
+ The `sincos` operation computes both the sine and cosine of a given value
+ simultaneously. It takes one operand of floating point type (i.e., scalar,
+ tensor or vector) and returns two results of the same type. This operation
+ can be more efficient than computing sine and cosine separately when both
+ values are needed.
+
+ Example:
+
+ ```mlir
+ // Scalar sine and cosine values.
+ %sin, %cos = math.sincos %input : f64 `->` f64, f64
+ ```
+ }];
+
+ let arguments = (ins FloatLike:$operand,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath);
+ let results = (outs FloatLike:$sin, FloatLike:$cos);
+
+ let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($operand) `->` type($sin) `,` type($cos) }];
+
+ let extraClassDeclaration = [{
+ std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
+ }];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index a95263bb55f69..16d765f2b2561 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
- LLVM::SqrtOp>();
+ LLVM::SincosOp, LLVM::SqrtOp>();
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,94 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
});
}
+// Custom lowering for math.sincos to __nv_sincosf/__nv_sincos libdevice calls
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value input = adaptor.getOperand();
+ Type inputType = input.getType();
+ auto convertedInput = maybeExt(input, rewriter);
+ auto computeType = convertedInput.getType();
+
+ StringRef sincosFunc;
+ if (isa<Float32Type>(computeType)) {
+ const arith::FastMathFlags flag = op.getFastmath();
+ const bool useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
+ sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+ } else if (isa<Float64Type>(computeType)) {
+ sincosFunc = "__nv_sincos";
+ } else {
+ return rewriter.notifyMatchFailure(op, "unsupported operand type for sincos");
+ }
+
+ auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+ Value sinPtr, cosPtr;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto *scope = op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+ assert(scope && "Expected op to be inside automatic allocation scope");
+ rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
+ auto one = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+ sinPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ cosPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ }
+
+ createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, op);
+
+ auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
+ auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+
+ rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
+ maybeTrunc(cosResult, inputType, rewriter)});
+ return success();
+ }
+
+private:
+ Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+ if (isa<Float16Type, BFloat16Type>(operand.getType())) {
+ return rewriter.create<LLVM::FPExtOp>(operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ }
+ return operand;
+ }
+
+ Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+ if (operand.getType() != type)
+ return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+ return operand;
+ }
+
+ void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+ StringRef funcName, Value input, Value sinPtr, Value cosPtr,
+ Operation *op) const {
+ auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+ auto ptrType = sinPtr.getType();
+
+ SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
+ auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
+
+ auto funcAttr = StringAttr::get(op->getContext(), funcName);
+ auto funcOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
+ if (!funcOp) {
+ auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+ assert(parentFunc && "expected there to be a parent function");
+ OpBuilder b(parentFunc);
+
+ auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+ funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
+ }
+
+ SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
+ rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+ }
+};
+
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
@@ -589,6 +677,9 @@ void mlir::populateLibDeviceConversionPatterns(
"__nv_tan", "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
"__nv_tanh");
+
+ // Custom pattern for sincos since it returns two values
+ patterns.add<SincosOpLowering>(converter, benefit);
}
void mlir::populateGpuToNVVMConversionPatterns(
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 853f45498ac52..73a003ef4e6c1 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering =
LLVM::CountTrailingZerosOp>;
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
+// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *this->getTypeConverter();
+ auto loc = op.getLoc();
+ auto operandType = adaptor.getOperand().getType();
+ auto llvmOperandType = typeConverter.convertType(operandType);
+ auto sinType = typeConverter.convertType(op.getSin().getType());
+ auto cosType = typeConverter.convertType(op.getCos().getType());
+ if (!llvmOperandType || !sinType || !cosType)
+ return failure();
+
+ ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
+
+ auto structType = LLVM::LLVMStructType::getLiteral(
+ rewriter.getContext(), {llvmOperandType, llvmOperandType});
+
+ auto sincosOp = rewriter.create<LLVM::SincosOp>(
+ loc, structType, adaptor.getOperand(), attrs.getAttrs());
+
+ auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
+ auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
+
+ rewriter.replaceOp(op, {sinValue, cosValue});
+ return success();
+ }
+};
+
// A `expm1` is converted into `exp - 1`.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns(
RoundEvenOpLowering,
RoundOpLowering,
RsqrtOpLowering,
+ SincosOpLowering,
SinOpLowering,
SinhOpLowering,
ASinOpLowering,
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index a21631cbf8510..f0bf62770d4cc 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -284,6 +284,28 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// SinCosOp verifier and getShapeForUnroll
+//===----------------------------------------------------------------------===//
+
+LogicalResult math::SincosOp::verify() {
+ Type operandType = getOperand().getType();
+ Type sinType = getSin().getType();
+ Type cosType = getCos().getType();
+
+ if (operandType != sinType || operandType != cosType) {
+ return emitOpError("result types must match operand type");
+ }
+
+ return success();
+}
+
+std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
+ if (auto vt = mlir::dyn_cast_or_null<VectorType>(getOperand().getType()))
+ return llvm::to_vector<4>(vt.getShape());
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ef06af3ad3163..cdefc4d6098c7 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1109,3 +1109,42 @@ gpu.module @test_module_55 {
func.return %result32, %result64 : f32, f64
}
}
+
+gpu.module @test_module_56 {
+ // CHECK: gpu.module @test_module_56
+
+ // CHECK-DAG: llvm.func @__nv_sincosf(f32, !llvm.ptr, !llvm.ptr)
+ // CHECK-DAG: llvm.func @__nv_sincos(f64, !llvm.ptr, !llvm.ptr)
+
+ // CHECK-LABEL: func @gpu_sincos
+ // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+ func.func @gpu_sincos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+ // CHECK-COUNT-6: llvm.alloca
+ // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+ // CHECK: llvm.call @__nv_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK-COUNT-2: llvm.fptrunc
+ // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+ %sin16, %cos16 = math.sincos %arg_f16 : f16 -> f16, f16
+ %sin32, %cos32 = math.sincos %arg_f32 : f32 -> f32, f32
+ %sin64, %cos64 = math.sincos %arg_f64 : f64 -> f64, f64
+ func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+ }
+
+ // CHECK: llvm.func @__nv_fast_sincosf(f32, !llvm.ptr, !llvm.ptr)
+
+ // CHECK-LABEL: func @gpu_sincos_fastmath
+ // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+ func.func @gpu_sincos_fastmath(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+ // CHECK-COUNT-6: llvm.alloca
+ // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+ // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK-COUNT-2: llvm.fptrunc
+ // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+ %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16 -> f16, f16
+ %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32 -> f32, f32
+ %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64 -> f64, f64
+ func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+ }
+}
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index f4541220fe4d2..9030ba9c93e55 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -230,6 +230,16 @@ func.func @trigonometrics(%arg0: f32) {
// -----
+// CHECK-LABEL: func @sincos
+// CHECK-SAME: [[ARG0:%.+]]: f32
+func.func @sincos(%arg0: f32) {
+ // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)>
+ %0:2 = math.sincos %arg0 : f32 -> f32, f32
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @inverse_trigonometrics
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @inverse_trigonometrics(%arg0: f32) {
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index cb10fc4397ffc..5d3a8a6d87bed 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -62,6 +62,18 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}
+// CHECK-LABEL: func @sincos(
+// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+ // CHECK: %{{.*}} = math.sincos %[[F]] : f32
+ %0:2 = math.sincos %f : f32 -> f32, f32
+ // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32>
+ %1:2 = math.sincos %v : vector<4xf32> -> vector<4xf32>, vector<4xf32>
+ // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32>
+ %2:2 = math.sincos %t : tensor<4x4x?xf32> -> tensor<4x4x?xf32>, tensor<4x4x?xf32>
+ return
+}
+
// CHECK-LABEL: func @erf(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
>From 6e2b34c11fd688312bcab520d4e74b5e12e10ae9 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 13:24:24 -0700
Subject: [PATCH 2/7] Formatting
---
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 40 +++++++++++--------
1 file changed, 24 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 16d765f2b2561..2c0a3305518e1 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -482,29 +482,35 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
StringRef sincosFunc;
if (isa<Float32Type>(computeType)) {
const arith::FastMathFlags flag = op.getFastmath();
- const bool useApprox = ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
+ const bool useApprox =
+ ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
} else if (isa<Float64Type>(computeType)) {
sincosFunc = "__nv_sincos";
} else {
- return rewriter.notifyMatchFailure(op, "unsupported operand type for sincos");
+ return rewriter.notifyMatchFailure(op,
+ "unsupported operand type for sincos");
}
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
-
+
Value sinPtr, cosPtr;
{
OpBuilder::InsertionGuard guard(rewriter);
- auto *scope = op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+ auto *scope =
+ op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
assert(scope && "Expected op to be inside automatic allocation scope");
rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
auto one = rewriter.create<LLVM::ConstantOp>(
loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
- sinPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
- cosPtr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ sinPtr =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ cosPtr =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
}
- createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr, op);
+ createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
+ op);
auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
@@ -517,7 +523,8 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
private:
Value maybeExt(Value operand, PatternRewriter &rewriter) const {
if (isa<Float16Type, BFloat16Type>(operand.getType())) {
- return rewriter.create<LLVM::FPExtOp>(operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ return rewriter.create<LLVM::FPExtOp>(
+ operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
}
return operand;
}
@@ -529,26 +536,27 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
}
void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
- StringRef funcName, Value input, Value sinPtr, Value cosPtr,
- Operation *op) const {
+ StringRef funcName, Value input, Value sinPtr,
+ Value cosPtr, Operation *op) const {
auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
auto ptrType = sinPtr.getType();
-
+
SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
-
+
auto funcAttr = StringAttr::get(op->getContext(), funcName);
- auto funcOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
-
+ auto funcOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
if (!funcOp) {
auto parentFunc = op->getParentOfType<FunctionOpInterface>();
assert(parentFunc && "expected there to be a parent function");
OpBuilder b(parentFunc);
-
+
auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
}
-
+
SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
}
>From dfea012c6c10386620c341eba82a690af926e969 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 13:26:32 -0700
Subject: [PATCH 3/7] Remove needless comment
---
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2c0a3305518e1..2b46a01c3b0e5 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -466,7 +466,6 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
});
}
-// Custom lowering for math.sincos to __nv_sincosf/__nv_sincos libdevice calls
struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
>From 1a24ecca6e955b70c09f11bccc4ee1bc6b41a1fc Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 13:43:30 -0700
Subject: [PATCH 4/7] Remove braces on single-line if
---
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 3 +--
mlir/lib/Dialect/Math/IR/MathOps.cpp | 3 +--
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2b46a01c3b0e5..f8f2104d2bd6a 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -521,10 +521,9 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
private:
Value maybeExt(Value operand, PatternRewriter &rewriter) const {
- if (isa<Float16Type, BFloat16Type>(operand.getType())) {
+ if (isa<Float16Type, BFloat16Type>(operand.getType()))
return rewriter.create<LLVM::FPExtOp>(
operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
- }
return operand;
}
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index f0bf62770d4cc..0de5636c27c3f 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -293,9 +293,8 @@ LogicalResult math::SincosOp::verify() {
Type sinType = getSin().getType();
Type cosType = getCos().getType();
- if (operandType != sinType || operandType != cosType) {
+ if (operandType != sinType || operandType != cosType)
return emitOpError("result types must match operand type");
- }
return success();
}
>From 145610739e99193970accde6f4e9596eb7fe4f3b Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 14:18:40 -0700
Subject: [PATCH 5/7] Refine assembly format
---
mlir/include/mlir/Dialect/Math/IR/MathOps.td | 7 ++++---
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 12 ++++++------
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir | 2 +-
mlir/test/Dialect/Math/ops.mlir | 6 +++---
4 files changed, 14 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index a7e79f2efd4c5..b4212056694e9 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -516,7 +516,8 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
def Math_SincosOp : Math_Op<"sincos",
[SameOperandsAndResultShape,
- DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ AllTypesMatch<["operand", "sin", "cos"]>]> {
let summary = "sine and cosine of the specified value";
let description = [{
The `sincos` operation computes both the sine and cosine of a given value
@@ -529,7 +530,7 @@ def Math_SincosOp : Math_Op<"sincos",
```mlir
// Scalar sine and cosine values.
- %sin, %cos = math.sincos %input : f64 `->` f64, f64
+ %sin, %cos = math.sincos %input : f64
```
}];
@@ -539,7 +540,7 @@ def Math_SincosOp : Math_Op<"sincos",
let results = (outs FloatLike:$sin, FloatLike:$cos);
let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
- attr-dict `:` type($operand) `->` type($sin) `,` type($cos) }];
+ attr-dict `:` type($operand) }];
let extraClassDeclaration = [{
std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index cdefc4d6098c7..a4b5dde8a2187 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1125,9 +1125,9 @@ gpu.module @test_module_56 {
// CHECK-COUNT-2: llvm.fptrunc
// CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
// CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
- %sin16, %cos16 = math.sincos %arg_f16 : f16 -> f16, f16
- %sin32, %cos32 = math.sincos %arg_f32 : f32 -> f32, f32
- %sin64, %cos64 = math.sincos %arg_f64 : f64 -> f64, f64
+ %sin16, %cos16 = math.sincos %arg_f16 : f16
+ %sin32, %cos32 = math.sincos %arg_f32 : f32
+ %sin64, %cos64 = math.sincos %arg_f64 : f64
func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
}
@@ -1142,9 +1142,9 @@ gpu.module @test_module_56 {
// CHECK-COUNT-2: llvm.fptrunc
// CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
// CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
- %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16 -> f16, f16
- %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32 -> f32, f32
- %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64 -> f64, f64
+ %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16
+ %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32
+ %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64
func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
}
}
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 9030ba9c93e55..f7d27120d4207 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -234,7 +234,7 @@ func.func @trigonometrics(%arg0: f32) {
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @sincos(%arg0: f32) {
// CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)>
- %0:2 = math.sincos %arg0 : f32 -> f32, f32
+ %0:2 = math.sincos %arg0 : f32
func.return
}
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index 5d3a8a6d87bed..f085d1c62ea86 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -66,11 +66,11 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
// CHECK: %{{.*}} = math.sincos %[[F]] : f32
- %0:2 = math.sincos %f : f32 -> f32, f32
+ %0:2 = math.sincos %f : f32
// CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32>
- %1:2 = math.sincos %v : vector<4xf32> -> vector<4xf32>, vector<4xf32>
+ %1:2 = math.sincos %v : vector<4xf32>
// CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32>
- %2:2 = math.sincos %t : tensor<4x4x?xf32> -> tensor<4x4x?xf32>, tensor<4x4x?xf32>
+ %2:2 = math.sincos %t : tensor<4x4x?xf32>
return
}
>From faa84a92b9baa9a1ed4a184abc84b4db44b329c4 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 25 Sep 2025 17:22:51 -0700
Subject: [PATCH 6/7] Remove custom verifier; clean up FMF handling
---
mlir/include/mlir/Dialect/Math/IR/MathOps.td | 2 --
.../Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +-
mlir/lib/Dialect/Math/IR/MathOps.cpp | 13 +------------
3 files changed, 2 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index b4212056694e9..af65af6fedec6 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -545,8 +545,6 @@ def Math_SincosOp : Math_Op<"sincos",
let extraClassDeclaration = [{
std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
}];
-
- let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index f8f2104d2bd6a..852c50c965f11 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -482,7 +482,7 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
if (isa<Float32Type>(computeType)) {
const arith::FastMathFlags flag = op.getFastmath();
const bool useApprox =
- ((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag);
+ mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
} else if (isa<Float64Type>(computeType)) {
sincosFunc = "__nv_sincos";
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 0de5636c27c3f..ca2792dd177e5 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -285,20 +285,9 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
}
//===----------------------------------------------------------------------===//
-// SinCosOp verifier and getShapeForUnroll
+// SinCosOp getShapeForUnroll
//===----------------------------------------------------------------------===//
-LogicalResult math::SincosOp::verify() {
- Type operandType = getOperand().getType();
- Type sinType = getSin().getType();
- Type cosType = getCos().getType();
-
- if (operandType != sinType || operandType != cosType)
- return emitOpError("result types must match operand type");
-
- return success();
-}
-
std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
if (auto vt = mlir::dyn_cast_or_null<VectorType>(getOperand().getType()))
return llvm::to_vector<4>(vt.getShape());
>From 90ef640e7dd16880d700220c2ce41666325281b8 Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Tue, 30 Sep 2025 07:23:09 -0700
Subject: [PATCH 7/7] Spell out types; use dyn_cast on non-nullable
---
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 12 ++++++------
mlir/lib/Dialect/Math/IR/MathOps.cpp | 2 +-
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 73a003ef4e6c1..229e40e2061cb 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -128,12 +128,12 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
LogicalResult
matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- const auto &typeConverter = *this->getTypeConverter();
- auto loc = op.getLoc();
- auto operandType = adaptor.getOperand().getType();
- auto llvmOperandType = typeConverter.convertType(operandType);
- auto sinType = typeConverter.convertType(op.getSin().getType());
- auto cosType = typeConverter.convertType(op.getCos().getType());
+ const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
+ mlir::Location loc = op.getLoc();
+ mlir::Type operandType = adaptor.getOperand().getType();
+ mlir::Type llvmOperandType = typeConverter.convertType(operandType);
+ mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
+ mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
if (!llvmOperandType || !sinType || !cosType)
return failure();
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index ca2792dd177e5..bbeef0f6ee9e5 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -289,7 +289,7 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
- if (auto vt = mlir::dyn_cast_or_null<VectorType>(getOperand().getType()))
+ if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
return llvm::to_vector<4>(vt.getShape());
return std::nullopt;
}
More information about the Mlir-commits
mailing list