[Mlir-commits] [mlir] 82efd72 - [MLIR] Add sincos op to math dialect (#160772)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Sep 30 07:36:18 PDT 2025
Author: Asher Mancinelli
Date: 2025-09-30T07:36:13-07:00
New Revision: 82efd72ed505c6ec183eca700290a29051c2d6e6
URL: https://github.com/llvm/llvm-project/commit/82efd72ed505c6ec183eca700290a29051c2d6e6
DIFF: https://github.com/llvm/llvm-project/commit/82efd72ed505c6ec183eca700290a29051c2d6e6.diff
LOG: [MLIR] Add sincos op to math dialect (#160772)
Now that `sincos` is a supported intrinsic in the LLVM dialect
(#160561), we are able to add the corresponding operation in
the math dialect and add conversion patterns for LLVM and NVVM.
We have several benchmarks that use sine and cosine in hot loops, and
saving some calculations by computing the two together can benefit
performance. We would like to have a way to represent `sincos` in the math
dialect.
Added:
Modified:
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Dialect/Math/IR/MathOps.cpp
mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
mlir/test/Dialect/Math/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index cfd8c4b8f11f7..af65af6fedec6 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -510,6 +510,43 @@ def Math_SinhOp : Math_FloatUnaryOp<"sinh"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// SincosOp
+//===----------------------------------------------------------------------===//
+
+def Math_SincosOp : Math_Op<"sincos",
+ [SameOperandsAndResultShape,
+ DeclareOpInterfaceMethods<ArithFastMathInterface>,
+ AllTypesMatch<["operand", "sin", "cos"]>]> {
+ let summary = "sine and cosine of the specified value";
+ let description = [{
+ The `sincos` operation computes both the sine and cosine of a given value
+ simultaneously. It takes one operand of floating point type (i.e., scalar,
+ tensor or vector) and returns two results of the same type. This operation
+ can be more efficient than computing sine and cosine separately when both
+ values are needed.
+
+ Example:
+
+ ```mlir
+ // Scalar sine and cosine values.
+ %sin, %cos = math.sincos %input : f64
+ ```
+ }];
+
+ let arguments = (ins FloatLike:$operand,
+ DefaultValuedAttr<Arith_FastMathAttr,
+ "::mlir::arith::FastMathFlags::none">:$fastmath);
+ let results = (outs FloatLike:$sin, FloatLike:$cos);
+
+ let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)?
+ attr-dict `:` type($operand) }];
+
+ let extraClassDeclaration = [{
+ std::optional<SmallVector<int64_t, 4>> getShapeForUnroll();
+ }];
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index a95263bb55f69..852c50c965f11 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -436,7 +436,7 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
LLVM::FAbsOp, LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp,
LLVM::LogOp, LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp,
LLVM::RoundEvenOp, LLVM::RoundOp, LLVM::SinOp,
- LLVM::SqrtOp>();
+ LLVM::SincosOp, LLVM::SqrtOp>();
// TODO: Remove once we support replacing non-root ops.
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -466,6 +466,100 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
});
}
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value input = adaptor.getOperand();
+ Type inputType = input.getType();
+ auto convertedInput = maybeExt(input, rewriter);
+ auto computeType = convertedInput.getType();
+
+ StringRef sincosFunc;
+ if (isa<Float32Type>(computeType)) {
+ const arith::FastMathFlags flag = op.getFastmath();
+ const bool useApprox =
+ mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
+ sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+ } else if (isa<Float64Type>(computeType)) {
+ sincosFunc = "__nv_sincos";
+ } else {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported operand type for sincos");
+ }
+
+ auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+ Value sinPtr, cosPtr;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto *scope =
+ op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+ assert(scope && "Expected op to be inside automatic allocation scope");
+ rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
+ auto one = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+ sinPtr =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ cosPtr =
+ rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+ }
+
+ createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
+ op);
+
+ auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
+ auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+
+ rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
+ maybeTrunc(cosResult, inputType, rewriter)});
+ return success();
+ }
+
+private:
+ Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+ if (isa<Float16Type, BFloat16Type>(operand.getType()))
+ return rewriter.create<LLVM::FPExtOp>(
+ operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+ return operand;
+ }
+
+ Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+ if (operand.getType() != type)
+ return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+ return operand;
+ }
+
+ void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+ StringRef funcName, Value input, Value sinPtr,
+ Value cosPtr, Operation *op) const {
+ auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+ auto ptrType = sinPtr.getType();
+
+ SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
+ auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
+
+ auto funcAttr = StringAttr::get(op->getContext(), funcName);
+ auto funcOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
+ if (!funcOp) {
+ auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+ assert(parentFunc && "expected there to be a parent function");
+ OpBuilder b(parentFunc);
+
+ auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+ funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
+ }
+
+ SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
+ rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+ }
+};
+
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
@@ -589,6 +683,9 @@ void mlir::populateLibDeviceConversionPatterns(
"__nv_tan", "__nv_fast_tanf");
populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
"__nv_tanh");
+
+ // Custom pattern for sincos since it returns two values
+ patterns.add<SincosOpLowering>(converter, benefit);
}
void mlir::populateGpuToNVVMConversionPatterns(
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 853f45498ac52..229e40e2061cb 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -121,6 +121,38 @@ using CountTrailingZerosOpLowering =
LLVM::CountTrailingZerosOp>;
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
+// A `sincos` is converted into `llvm.intr.sincos` followed by extractvalue ops.
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const LLVMTypeConverter &typeConverter = *this->getTypeConverter();
+ mlir::Location loc = op.getLoc();
+ mlir::Type operandType = adaptor.getOperand().getType();
+ mlir::Type llvmOperandType = typeConverter.convertType(operandType);
+ mlir::Type sinType = typeConverter.convertType(op.getSin().getType());
+ mlir::Type cosType = typeConverter.convertType(op.getCos().getType());
+ if (!llvmOperandType || !sinType || !cosType)
+ return failure();
+
+ ConvertFastMath<math::SincosOp, LLVM::SincosOp> attrs(op);
+
+ auto structType = LLVM::LLVMStructType::getLiteral(
+ rewriter.getContext(), {llvmOperandType, llvmOperandType});
+
+ auto sincosOp = rewriter.create<LLVM::SincosOp>(
+ loc, structType, adaptor.getOperand(), attrs.getAttrs());
+
+ auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
+ auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
+
+ rewriter.replaceOp(op, {sinValue, cosValue});
+ return success();
+ }
+};
+
// A `expm1` is converted into `exp - 1`.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
@@ -393,6 +425,7 @@ void mlir::populateMathToLLVMConversionPatterns(
RoundEvenOpLowering,
RoundOpLowering,
RsqrtOpLowering,
+ SincosOpLowering,
SinOpLowering,
SinhOpLowering,
ASinOpLowering,
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index a21631cbf8510..bbeef0f6ee9e5 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -284,6 +284,16 @@ OpFoldResult math::SinhOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// SincosOp getShapeForUnroll
+//===----------------------------------------------------------------------===//
+
+std::optional<SmallVector<int64_t, 4>> math::SincosOp::getShapeForUnroll() {
+ if (auto vt = mlir::dyn_cast<VectorType>(getOperand().getType()))
+ return llvm::to_vector<4>(vt.getShape());
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// CountLeadingZerosOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ef06af3ad3163..a4b5dde8a2187 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1109,3 +1109,42 @@ gpu.module @test_module_55 {
func.return %result32, %result64 : f32, f64
}
}
+
+gpu.module @test_module_56 {
+ // CHECK: gpu.module @test_module_56
+
+ // CHECK-DAG: llvm.func @__nv_sincosf(f32, !llvm.ptr, !llvm.ptr)
+ // CHECK-DAG: llvm.func @__nv_sincos(f64, !llvm.ptr, !llvm.ptr)
+
+ // CHECK-LABEL: func @gpu_sincos
+ // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+ func.func @gpu_sincos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+ // CHECK-COUNT-6: llvm.alloca
+ // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+ // CHECK: llvm.call @__nv_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK-COUNT-2: llvm.fptrunc
+ // CHECK: llvm.call @__nv_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+ %sin16, %cos16 = math.sincos %arg_f16 : f16
+ %sin32, %cos32 = math.sincos %arg_f32 : f32
+ %sin64, %cos64 = math.sincos %arg_f64 : f64
+ func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+ }
+
+ // CHECK: llvm.func @__nv_fast_sincosf(f32, !llvm.ptr, !llvm.ptr)
+
+ // CHECK-LABEL: func @gpu_sincos_fastmath
+ // CHECK-SAME: %[[ARG_f16:.*]]: f16, %[[ARG_f32:.*]]: f32, %[[ARG_f64:.*]]: f64
+ func.func @gpu_sincos_fastmath(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f16, f32, f32, f64, f64) {
+ // CHECK-COUNT-6: llvm.alloca
+ // CHECK: %[[ARG_f16_ext:.*]] = llvm.fpext %[[ARG_f16]] : f16 to f32
+ // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f16_ext]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK-COUNT-2: llvm.fptrunc
+ // CHECK: llvm.call @__nv_fast_sincosf(%[[ARG_f32]], %{{.+}}, %{{.+}}) : (f32, !llvm.ptr, !llvm.ptr) -> ()
+ // CHECK: llvm.call @__nv_sincos(%[[ARG_f64]], %{{.+}}, %{{.+}}) : (f64, !llvm.ptr, !llvm.ptr) -> ()
+ %sin16, %cos16 = math.sincos %arg_f16 fastmath<afn> : f16
+ %sin32, %cos32 = math.sincos %arg_f32 fastmath<afn> : f32
+ %sin64, %cos64 = math.sincos %arg_f64 fastmath<afn> : f64
+ func.return %sin16, %cos16, %sin32, %cos32, %sin64, %cos64 : f16, f16, f32, f32, f64, f64
+ }
+}
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index f4541220fe4d2..f7d27120d4207 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -230,6 +230,16 @@ func.func @trigonometrics(%arg0: f32) {
// -----
+// CHECK-LABEL: func @sincos
+// CHECK-SAME: [[ARG0:%.+]]: f32
+func.func @sincos(%arg0: f32) {
+ // CHECK: llvm.intr.sincos([[ARG0]]) : (f32) -> !llvm.struct<(f32, f32)>
+ %0:2 = math.sincos %arg0 : f32
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @inverse_trigonometrics
// CHECK-SAME: [[ARG0:%.+]]: f32
func.func @inverse_trigonometrics(%arg0: f32) {
diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir
index cb10fc4397ffc..f085d1c62ea86 100644
--- a/mlir/test/Dialect/Math/ops.mlir
+++ b/mlir/test/Dialect/Math/ops.mlir
@@ -62,6 +62,18 @@ func.func @sin(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
return
}
+// CHECK-LABEL: func @sincos(
+// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
+func.func @sincos(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
+ // CHECK: %{{.*}} = math.sincos %[[F]] : f32
+ %0:2 = math.sincos %f : f32
+ // CHECK: %{{.*}} = math.sincos %[[V]] : vector<4xf32>
+ %1:2 = math.sincos %v : vector<4xf32>
+ // CHECK: %{{.*}} = math.sincos %[[T]] : tensor<4x4x?xf32>
+ %2:2 = math.sincos %t : tensor<4x4x?xf32>
+ return
+}
+
// CHECK-LABEL: func @erf(
// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>)
func.func @erf(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) {
More information about the Mlir-commits
mailing list