[Mlir-commits] [mlir] [NVVM][MLIR] Refactor conversion of Math / Arith Operations into separate Passes (PR #180058)
Jason Van Beusekom
llvmlistbot at llvm.org
Fri Feb 13 09:58:35 PST 2026
https://github.com/Jason-Van-Beusekom updated https://github.com/llvm/llvm-project/pull/180058
>From 1f39bcf0643a73c840c39b7ee8136d21f4e5cd7e Mon Sep 17 00:00:00 2001
From: jason-van-beusekom <jason.van-beusekom at hpe.com>
Date: Thu, 5 Feb 2026 16:24:15 -0600
Subject: [PATCH 1/4] [NVVM][MLIR] Refactor conversion of Math / Arith
Operations into separate Passes
This commit refactors the conversion of Math / Arith operations to NVVM into
a separate pass called MathToNVVM. This was done to support the lowering of
Math / Arith operations in flang, and it mirrors what was done in
MathToROCDL.
---
.../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 6 -
.../mlir/Conversion/MathToNVVM/MathToNVVM.h | 28 ++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 14 +
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt | 1 +
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 220 +-------------
mlir/lib/Conversion/MathToNVVM/CMakeLists.txt | 26 ++
mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp | 279 ++++++++++++++++++
9 files changed, 352 insertions(+), 224 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
create mode 100644 mlir/lib/Conversion/MathToNVVM/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 48982ac6efe7c..9d85f04b40c72 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -36,12 +36,6 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target);
/// GPU dialect to NVVM.
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
-/// Populate patterns that lower certain arith and math dialect ops to
-/// libdevice calls.
-void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
-
/// Collect a set of patterns to convert from the GPU dialect to NVVM.
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
new file mode 100644
index 0000000000000..e0e2b2c2e08c3
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
@@ -0,0 +1,28 @@
+//===- MathToNVVM.h - Utils to convert from the Math dialect to NVVM -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_
+#define MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTONVVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to NVVM
+/// libdevice calls.
+void populateMathToNVVMConversionPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 7c2b450ca6710..a54b98004c3b6 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1096338534416..fd9cbddbd7ab0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -827,6 +827,20 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"Chipset that these operations will run on">];
}
+//===----------------------------------------------------------------------===//
+// MathToNVVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToNVVM : Pass<"convert-math-to-nvvm", "ModuleOp"> {
+ let summary = "Convert Math dialect to NVVM libdevice calls";
+ let description = [{
+ This pass converts supported Math ops to NVVM libdevice calls.
+ }];
+ let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
+ "NVVM::NVVMDialect", "vector::VectorDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// MathToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 2ed10effb53da..e17988b12cade 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToEmitC)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
+add_subdirectory(MathToNVVM)
add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
add_subdirectory(MathToXeVM)
diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
index 983aadf2c1517..681d788aa54dd 100644
--- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRGPUToNVVMTransforms
MLIRGPUToGPURuntimeTransforms
MLIRLLVMCommonConversion
MLIRLLVMDialect
+ MLIRMathToNVVM
MLIRMemRefToLLVM
MLIRNVGPUDialect
MLIRNVGPUToNVVM
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 5fdfc9fa8cdb6..4d963c1681511 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -456,229 +457,12 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
});
}
-struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
- using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- Value input = adaptor.getOperand();
- Type inputType = input.getType();
- auto convertedInput = maybeExt(input, rewriter);
- auto computeType = convertedInput.getType();
-
- StringRef sincosFunc;
- if (isa<Float32Type>(computeType)) {
- const arith::FastMathFlags flag = op.getFastmath();
- const bool useApprox =
- mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
- sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
- } else if (isa<Float64Type>(computeType)) {
- sincosFunc = "__nv_sincos";
- } else {
- return rewriter.notifyMatchFailure(op,
- "unsupported operand type for sincos");
- }
-
- auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
-
- Value sinPtr, cosPtr;
- {
- OpBuilder::InsertionGuard guard(rewriter);
- auto *scope =
- op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
- assert(scope && "Expected op to be inside automatic allocation scope");
- rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
- auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
- rewriter.getI32IntegerAttr(1));
- sinPtr =
- LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
- cosPtr =
- LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
- }
-
- createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
- op);
-
- auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
- auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
-
- rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
- maybeTrunc(cosResult, inputType, rewriter)});
- return success();
- }
-
-private:
- Value maybeExt(Value operand, PatternRewriter &rewriter) const {
- if (isa<Float16Type, BFloat16Type>(operand.getType()))
- return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
- Float32Type::get(rewriter.getContext()),
- operand);
- return operand;
- }
-
- Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
- if (operand.getType() != type)
- return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
- return operand;
- }
-
- void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
- StringRef funcName, Value input, Value sinPtr,
- Value cosPtr, Operation *op) const {
- auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
- auto ptrType = sinPtr.getType();
-
- SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
- auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
-
- auto funcAttr = StringAttr::get(op->getContext(), funcName);
- auto funcOp =
- SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
-
- if (!funcOp) {
- auto parentFunc = op->getParentOfType<FunctionOpInterface>();
- assert(parentFunc && "expected there to be a parent function");
- OpBuilder b(parentFunc);
-
- auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
- funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
- }
-
- SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
- LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
- }
-};
-
-template <typename OpTy>
-static void populateOpPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- PatternBenefit benefit, StringRef f32Func,
- StringRef f64Func, StringRef f32ApproxFunc = "",
- StringRef f16Func = "") {
- patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
- f32ApproxFunc, f16Func,
- /*i32Func=*/"", benefit);
-}
-
-template <typename OpTy>
-static void populateIntOpPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- PatternBenefit benefit, StringRef i32Func) {
- patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
- benefit);
-}
-
-template <typename OpTy>
-static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- PatternBenefit benefit,
- StringRef f32Func, StringRef f64Func) {
- patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
- /*i32Func=*/"", benefit);
-}
-
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
}
-void mlir::populateLibDeviceConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- PatternBenefit benefit) {
- populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
- "__nv_fmod");
- populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
- "__nv_fmaxf", "__nv_fmax");
- populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
- "__nv_fminf", "__nv_fmin");
-
- populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
- populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
- "__nv_fabs");
- populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
- "__nv_acos");
- populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
- "__nv_acosh");
- populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
- "__nv_asin");
- populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
- "__nv_asinh");
- populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
- "__nv_atan");
- populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
- "__nv_atan2");
- populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
- "__nv_atanh");
- populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
- "__nv_cbrt");
- populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
- "__nv_ceil");
- populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
- "__nv_copysignf", "__nv_copysign");
- populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
- "__nv_cos", "__nv_fast_cosf");
- populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
- "__nv_cosh");
- populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
- "__nv_erf");
- populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
- "__nv_erfc");
- populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
- "__nv_exp", "__nv_fast_expf");
- populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
- "__nv_exp2");
- populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
- "__nv_expm1");
- populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
- "__nv_floor");
- populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
- "__nv_fma");
- // Note: libdevice uses a different name for 32-bit finite checking
- populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
- "__nv_finitef", "__nv_isfinited");
- populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
- "__nv_isinfd");
- populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
- "__nv_isnand");
- populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
- "__nv_log", "__nv_fast_logf");
- populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
- "__nv_log10", "__nv_fast_log10f");
- populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
- "__nv_log1p");
- populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
- "__nv_log2", "__nv_fast_log2f");
- populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
- "__nv_pow", "__nv_fast_powf");
- populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
- "__nv_powif", "__nv_powi");
- populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
- "__nv_round");
- populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
- "__nv_rintf", "__nv_rint");
- populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
- "__nv_rsqrt");
- populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
- "__nv_sin", "__nv_fast_sinf");
- populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
- "__nv_sinh");
- populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
- "__nv_sqrt");
- populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
- "__nv_tan", "__nv_fast_tanf");
- populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
- "__nv_tanh");
-
- // Custom pattern for sincos since it returns two values
- patterns.add<SincosOpLowering>(converter, benefit);
-}
-
void mlir::populateGpuToNVVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
PatternBenefit benefit) {
@@ -743,7 +527,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
NVVM::NVVMDialect::getClusterDimAttrName())},
benefit);
- populateLibDeviceConversionPatterns(converter, patterns, benefit);
+ populateMathToNVVMConversionPatterns(converter, patterns, benefit);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToNVVM/CMakeLists.txt b/mlir/lib/Conversion/MathToNVVM/CMakeLists.txt
new file mode 100644
index 0000000000000..589700d32646c
--- /dev/null
+++ b/mlir/lib/Conversion/MathToNVVM/CMakeLists.txt
@@ -0,0 +1,26 @@
+add_mlir_conversion_library(MLIRMathToNVVM
+ MathToNVVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToNVVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRNVVMDialect
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
new file mode 100644
index 0000000000000..5ef3c1fd7f1b4
--- /dev/null
+++ b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
@@ -0,0 +1,279 @@
+//===-- MathToNVVM.cpp - conversion from Math to NVVM libdevice calls ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTONVVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-nvvm"
+
+template <typename OpTy>
+static void populateOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit, StringRef f32Func,
+ StringRef f64Func, StringRef f32ApproxFunc = "",
+ StringRef f16Func = "") {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+ f32ApproxFunc, f16Func,
+ /*i32Func=*/"", benefit);
+}
+
+template <typename OpTy>
+static void populateIntOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit, StringRef i32Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
+ benefit);
+}
+
+template <typename OpTy>
+static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ PatternBenefit benefit,
+ StringRef f32Func, StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
+ /*i32Func=*/"", benefit);
+}
+
+// Custom pattern for sincos since it returns two values
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+ using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ Value input = adaptor.getOperand();
+ Type inputType = input.getType();
+ auto convertedInput = maybeExt(input, rewriter);
+ auto computeType = convertedInput.getType();
+
+ StringRef sincosFunc;
+ if (isa<Float32Type>(computeType)) {
+ const arith::FastMathFlags flag = op.getFastmath();
+ const bool useApprox =
+ mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
+ sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+ } else if (isa<Float64Type>(computeType)) {
+ sincosFunc = "__nv_sincos";
+ } else {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported operand type for sincos");
+ }
+
+ auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+ Value sinPtr, cosPtr;
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto *scope =
+ op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+ assert(scope && "Expected op to be inside automatic allocation scope");
+ rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
+ auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+ rewriter.getI32IntegerAttr(1));
+ sinPtr =
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
+ cosPtr =
+ LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
+ }
+
+ createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
+ op);
+
+ auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
+ auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
+
+ rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
+ maybeTrunc(cosResult, inputType, rewriter)});
+ return success();
+ }
+
+private:
+ Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+ if (isa<Float16Type, BFloat16Type>(operand.getType()))
+ return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+ Float32Type::get(rewriter.getContext()),
+ operand);
+ return operand;
+ }
+
+ Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+ if (operand.getType() != type)
+ return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
+ return operand;
+ }
+
+ void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+ StringRef funcName, Value input, Value sinPtr,
+ Value cosPtr, Operation *op) const {
+ auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+ auto ptrType = sinPtr.getType();
+
+ SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
+ auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
+
+ auto funcAttr = StringAttr::get(op->getContext(), funcName);
+ auto funcOp =
+ SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
+ if (!funcOp) {
+ auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+ assert(parentFunc && "expected there to be a parent function");
+ OpBuilder b(parentFunc);
+
+ auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+ funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
+ }
+
+ SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
+ LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
+ }
+};
+
+void mlir::populateMathToNVVMConversionPatterns(
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ PatternBenefit benefit) {
+ populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
+ "__nv_fmod");
+ populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
+ "__nv_fmaxf", "__nv_fmax");
+ populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
+ "__nv_fminf", "__nv_fmin");
+
+ populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
+ populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
+ "__nv_fabs");
+ populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
+ "__nv_acos");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
+ "__nv_acosh");
+ populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
+ "__nv_asin");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
+ "__nv_asinh");
+ populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
+ "__nv_atan");
+ populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
+ "__nv_atan2");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
+ "__nv_atanh");
+ populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
+ "__nv_cbrt");
+ populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
+ "__nv_ceil");
+ populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
+ "__nv_copysignf", "__nv_copysign");
+ populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
+ "__nv_cos", "__nv_fast_cosf");
+ populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
+ "__nv_cosh");
+ populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
+ "__nv_erf");
+ populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
+ "__nv_erfc");
+ populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
+ "__nv_exp", "__nv_fast_expf");
+ populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
+ "__nv_exp2");
+ populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
+ "__nv_expm1");
+ populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
+ "__nv_floor");
+ populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
+ "__nv_fma");
+ // Note: libdevice uses a different name for 32-bit finite checking
+ populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
+ "__nv_finitef", "__nv_isfinited");
+ populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
+ "__nv_isinfd");
+ populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
+ "__nv_isnand");
+ populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
+ "__nv_log", "__nv_fast_logf");
+ populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
+ "__nv_log10", "__nv_fast_log10f");
+ populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
+ "__nv_log1p");
+ populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
+ "__nv_log2", "__nv_fast_log2f");
+ populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
+ "__nv_pow", "__nv_fast_powf");
+ populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
+ "__nv_powif", "__nv_powi");
+ populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
+ "__nv_round");
+ populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
+ "__nv_rintf", "__nv_rint");
+ populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
+ "__nv_rsqrt");
+ populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
+ "__nv_sin", "__nv_fast_sinf");
+ populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
+ "__nv_sinh");
+ populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
+ "__nv_sqrt");
+ populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
+ "__nv_tan", "__nv_fast_tanf");
+ populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
+ "__nv_tanh");
+
+ // Custom pattern for sincos since it returns two values
+ patterns.add<SincosOpLowering>(converter, benefit);
+}
+
+namespace {
+struct ConvertMathToNVVMPass final
+ : impl::ConvertMathToNVVMBase<ConvertMathToNVVMPass> {
+ using impl::ConvertMathToNVVMBase<
+ ConvertMathToNVVMPass>::ConvertMathToNVVMBase;
+
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToNVVMPass::runOnOperation() {
+ auto m = getOperation();
+ MLIRContext *ctx = m.getContext();
+
+ RewritePatternSet patterns(&getContext());
+ LowerToLLVMOptions options(ctx, DataLayout(m));
+ LLVMTypeConverter converter(ctx, options);
+
+ populateMathToNVVMConversionPatterns(converter, patterns, /*benefit=*/1);
+
+ ConversionTarget target(getContext());
+ target
+ .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
+ LLVM::LLVMDialect, NVVM::NVVMDialect>();
+ target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+ LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
+ LLVM::SqrtOp>();
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
>From d2d6ae8f8afed94f945204a201bebe82703ddd7e Mon Sep 17 00:00:00 2001
From: jason-van-beusekom <jason.van-beusekom at hpe.com>
Date: Tue, 10 Feb 2026 16:19:10 -0600
Subject: [PATCH 2/4] update name based on feedback
---
mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h | 2 +-
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +-
mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp | 4 ++--
3 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
index e0e2b2c2e08c3..4100479c6e5dd 100644
--- a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
+++ b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
@@ -20,7 +20,7 @@ class Pass;
/// Populate the given list with patterns that convert from Math to NVVM
/// libdevice calls.
-void populateMathToNVVMConversionPatterns(const LLVMTypeConverter &converter,
+void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
PatternBenefit benefit = 1);
} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 4d963c1681511..660b24b071b49 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -527,7 +527,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
NVVM::NVVMDialect::getClusterDimAttrName())},
benefit);
- populateMathToNVVMConversionPatterns(converter, patterns, benefit);
+ populateLibDeviceConversionPatterns(converter, patterns, benefit);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
index 5ef3c1fd7f1b4..f9e0b32911912 100644
--- a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
+++ b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
@@ -155,7 +155,7 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
}
};
-void mlir::populateMathToNVVMConversionPatterns(
+void mlir::populateLibDeviceConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
PatternBenefit benefit) {
populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
@@ -264,7 +264,7 @@ void ConvertMathToNVVMPass::runOnOperation() {
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToNVVMConversionPatterns(converter, patterns, /*benefit=*/1);
+ populateLibDeviceConversionPatterns(converter, patterns, /*benefit=*/1);
ConversionTarget target(getContext());
target
>From 957d063e7a78fbb919687033165b6f187a6a110c Mon Sep 17 00:00:00 2001
From: jason-van-beusekom <jason.van-beusekom at hpe.com>
Date: Tue, 10 Feb 2026 16:22:40 -0600
Subject: [PATCH 3/4] format
---
mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
index 4100479c6e5dd..e354d7dc4cbf1 100644
--- a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
+++ b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
@@ -21,8 +21,8 @@ class Pass;
/// Populate the given list with patterns that convert from Math to NVVM
/// libdevice calls.
void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
- RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+ RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_
>From e8c6cd91efe6bf8e0347e6dea01b3c900b71745b Mon Sep 17 00:00:00 2001
From: Jason Van Beusekom <jason.van-beusekom at hpe.com>
Date: Fri, 13 Feb 2026 11:58:25 -0600
Subject: [PATCH 4/4] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Valentin Clement (バレンタイン クレメン) <clementval at gmail.com>
---
mlir/include/mlir/Conversion/Passes.td | 4 ++--
mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fd9cbddbd7ab0..37a8bf2f45d72 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -832,9 +832,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
//===----------------------------------------------------------------------===//
def ConvertMathToNVVM : Pass<"convert-math-to-nvvm", "ModuleOp"> {
- let summary = "Convert Math dialect to NVVM libdevice calls";
+ let summary = "Convert Math dialect to CUDA libdevice calls";
let description = [{
- This pass converts supported Math ops to NVVM libdevice calls.
+ This pass converts supported Math ops to CUDA libdevice calls.
}];
let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
"NVVM::NVVMDialect", "vector::VectorDialect",
diff --git a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
index f9e0b32911912..80619f204df70 100644
--- a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
+++ b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
@@ -1,4 +1,4 @@
-//===-- MathToNVVM.cpp - conversion from Math to NVVM libdevice calls ----===//
+//===-- MathToNVVM.cpp - conversion from Math to CUDA libdevice calls ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
More information about the Mlir-commits
mailing list