[Mlir-commits] [mlir] [NVVM][MLIR] Refactor conversion of Math / Arith Operations into a separate pass (PR #180058)

Jason Van Beusekom llvmlistbot at llvm.org
Fri Feb 13 09:58:35 PST 2026


https://github.com/Jason-Van-Beusekom updated https://github.com/llvm/llvm-project/pull/180058

>From 1f39bcf0643a73c840c39b7ee8136d21f4e5cd7e Mon Sep 17 00:00:00 2001
From: jason-van-beusekom <jason.van-beusekom at hpe.com>
Date: Thu, 5 Feb 2026 16:24:15 -0600
Subject: [PATCH 1/4] [NVVM][MLIR] Refactor conversion of Math / Arith
 Operations into a separate pass

This commit moves the conversion of Math / Arith operations to NVVM
libdevice calls into a separate pass called MathToNVVM
(`convert-math-to-nvvm`). This was done to support lowering Math / Arith
operations in flang. The GPU-to-NVVM conversion still includes these
patterns. This mirrors what was done in MathToROCDL.
---
 .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h |   6 -
 .../mlir/Conversion/MathToNVVM/MathToNVVM.h   |  28 ++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  14 +
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt  |   1 +
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        | 220 +-------------
 mlir/lib/Conversion/MathToNVVM/CMakeLists.txt |  26 ++
 mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp | 279 ++++++++++++++++++
 9 files changed, 352 insertions(+), 224 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
 create mode 100644 mlir/lib/Conversion/MathToNVVM/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp

diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
index 48982ac6efe7c..9d85f04b40c72 100644
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -36,12 +36,6 @@ void configureGpuToNVVMConversionLegality(ConversionTarget &target);
 /// GPU dialect to NVVM.
 void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter);
 
-/// Populate patterns that lower certain arith and math dialect ops to
-/// libdevice calls.
-void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
-                                         RewritePatternSet &patterns,
-                                         PatternBenefit benefit = 1);
-
 /// Collect a set of patterns to convert from the GPU dialect to NVVM.
 void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
new file mode 100644
index 0000000000000..e0e2b2c2e08c3
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
@@ -0,0 +1,28 @@
+//===- MathToNVVM.h - Utils to convert from the Math dialect to NVVM -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_
+#define MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTONVVM
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to NVVM
+/// libdevice calls.
+void populateMathToNVVMConversionPatterns(const LLVMTypeConverter &converter,
+                                          RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 7c2b450ca6710..a54b98004c3b6 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -49,6 +49,7 @@
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
 #include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
 #include "mlir/Conversion/MathToXeVM/MathToXeVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1096338534416..fd9cbddbd7ab0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -827,6 +827,20 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
                         "Chipset that these operations will run on">];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToNVVM
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToNVVM : Pass<"convert-math-to-nvvm", "ModuleOp"> {
+  let summary = "Convert Math dialect to NVVM libdevice calls";
+  let description = [{
+    This pass converts supported Math ops to NVVM libdevice calls.
+  }];
+  let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
+                           "NVVM::NVVMDialect", "vector::VectorDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 2ed10effb53da..e17988b12cade 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -39,6 +39,7 @@ add_subdirectory(MathToEmitC)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
+add_subdirectory(MathToNVVM)
 add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
 add_subdirectory(MathToXeVM)
diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
index 983aadf2c1517..681d788aa54dd 100644
--- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRGPUToNVVMTransforms
   MLIRGPUToGPURuntimeTransforms
   MLIRLLVMCommonConversion
   MLIRLLVMDialect
+  MLIRMathToNVVM
   MLIRMemRefToLLVM
   MLIRNVGPUDialect
   MLIRNVGPUToNVVM
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 5fdfc9fa8cdb6..4d963c1681511 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -456,229 +457,12 @@ void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) {
   });
 }
 
-struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
-  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value input = adaptor.getOperand();
-    Type inputType = input.getType();
-    auto convertedInput = maybeExt(input, rewriter);
-    auto computeType = convertedInput.getType();
-
-    StringRef sincosFunc;
-    if (isa<Float32Type>(computeType)) {
-      const arith::FastMathFlags flag = op.getFastmath();
-      const bool useApprox =
-          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
-      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
-    } else if (isa<Float64Type>(computeType)) {
-      sincosFunc = "__nv_sincos";
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "unsupported operand type for sincos");
-    }
-
-    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
-
-    Value sinPtr, cosPtr;
-    {
-      OpBuilder::InsertionGuard guard(rewriter);
-      auto *scope =
-          op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
-      assert(scope && "Expected op to be inside automatic allocation scope");
-      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
-      auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
-                                          rewriter.getI32IntegerAttr(1));
-      sinPtr =
-          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
-      cosPtr =
-          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
-    }
-
-    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
-                     op);
-
-    auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
-    auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
-
-    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
-                            maybeTrunc(cosResult, inputType, rewriter)});
-    return success();
-  }
-
-private:
-  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
-    if (isa<Float16Type, BFloat16Type>(operand.getType()))
-      return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
-                                   Float32Type::get(rewriter.getContext()),
-                                   operand);
-    return operand;
-  }
-
-  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
-    if (operand.getType() != type)
-      return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
-    return operand;
-  }
-
-  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
-                        StringRef funcName, Value input, Value sinPtr,
-                        Value cosPtr, Operation *op) const {
-    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
-    auto ptrType = sinPtr.getType();
-
-    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
-    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
-
-    auto funcAttr = StringAttr::get(op->getContext(), funcName);
-    auto funcOp =
-        SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
-
-    if (!funcOp) {
-      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
-      assert(parentFunc && "expected there to be a parent function");
-      OpBuilder b(parentFunc);
-
-      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
-      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
-    }
-
-    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
-    LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
-  }
-};
-
-template <typename OpTy>
-static void populateOpPatterns(const LLVMTypeConverter &converter,
-                               RewritePatternSet &patterns,
-                               PatternBenefit benefit, StringRef f32Func,
-                               StringRef f64Func, StringRef f32ApproxFunc = "",
-                               StringRef f16Func = "") {
-  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc, f16Func,
-                                           /*i32Func=*/"", benefit);
-}
-
-template <typename OpTy>
-static void populateIntOpPatterns(const LLVMTypeConverter &converter,
-                                  RewritePatternSet &patterns,
-                                  PatternBenefit benefit, StringRef i32Func) {
-  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
-                                           benefit);
-}
-
-template <typename OpTy>
-static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
-                                       RewritePatternSet &patterns,
-                                       PatternBenefit benefit,
-                                       StringRef f32Func, StringRef f64Func) {
-  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
-                                           /*i32Func=*/"", benefit);
-}
-
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
   patterns.add<GPUSubgroupReduceOpLowering>(converter, benefit);
 }
 
-void mlir::populateLibDeviceConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    PatternBenefit benefit) {
-  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
-                                    "__nv_fmod");
-  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
-                                       "__nv_fmaxf", "__nv_fmax");
-  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
-                                       "__nv_fminf", "__nv_fmin");
-
-  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
-  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
-                                   "__nv_fabs");
-  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
-                                   "__nv_acos");
-  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
-                                    "__nv_acosh");
-  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
-                                   "__nv_asin");
-  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
-                                    "__nv_asinh");
-  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
-                                   "__nv_atan");
-  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
-                                    "__nv_atan2");
-  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
-                                    "__nv_atanh");
-  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
-                                   "__nv_cbrt");
-  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
-                                   "__nv_ceil");
-  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
-                                       "__nv_copysignf", "__nv_copysign");
-  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
-                                  "__nv_cos", "__nv_fast_cosf");
-  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
-                                   "__nv_cosh");
-  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
-                                  "__nv_erf");
-  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
-                                   "__nv_erfc");
-  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
-                                  "__nv_exp", "__nv_fast_expf");
-  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
-                                   "__nv_exp2");
-  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
-                                    "__nv_expm1");
-  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
-                                    "__nv_floor");
-  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
-                                  "__nv_fma");
-  // Note: libdevice uses a different name for 32-bit finite checking
-  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
-                                       "__nv_finitef", "__nv_isfinited");
-  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
-                                    "__nv_isinfd");
-  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
-                                    "__nv_isnand");
-  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
-                                  "__nv_log", "__nv_fast_logf");
-  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
-                                    "__nv_log10", "__nv_fast_log10f");
-  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
-                                    "__nv_log1p");
-  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
-                                   "__nv_log2", "__nv_fast_log2f");
-  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
-                                   "__nv_pow", "__nv_fast_powf");
-  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
-                                            "__nv_powif", "__nv_powi");
-  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
-                                    "__nv_round");
-  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
-                                        "__nv_rintf", "__nv_rint");
-  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
-                                    "__nv_rsqrt");
-  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
-                                  "__nv_sin", "__nv_fast_sinf");
-  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
-                                   "__nv_sinh");
-  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
-                                   "__nv_sqrt");
-  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
-                                  "__nv_tan", "__nv_fast_tanf");
-  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
-                                   "__nv_tanh");
-
-  // Custom pattern for sincos since it returns two values
-  patterns.add<SincosOpLowering>(converter, benefit);
-}
-
 void mlir::populateGpuToNVVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
@@ -743,7 +527,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
                           NVVM::NVVMDialect::getClusterDimAttrName())},
       benefit);
 
-  populateLibDeviceConversionPatterns(converter, patterns, benefit);
+  populateMathToNVVMConversionPatterns(converter, patterns, benefit);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToNVVM/CMakeLists.txt b/mlir/lib/Conversion/MathToNVVM/CMakeLists.txt
new file mode 100644
index 0000000000000..589700d32646c
--- /dev/null
+++ b/mlir/lib/Conversion/MathToNVVM/CMakeLists.txt
@@ -0,0 +1,26 @@
+add_mlir_conversion_library(MLIRMathToNVVM
+  MathToNVVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToNVVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRMathDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRNVVMDialect
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )
diff --git a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
new file mode 100644
index 0000000000000..5ef3c1fd7f1b4
--- /dev/null
+++ b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
@@ -0,0 +1,279 @@
+//===-- MathToNVVM.cpp - conversion from Math to NVVM libdevice calls ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTONVVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-nvvm"
+
+template <typename OpTy>
+static void populateOpPatterns(const LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns,
+                               PatternBenefit benefit, StringRef f32Func,
+                               StringRef f64Func, StringRef f32ApproxFunc = "",
+                               StringRef f16Func = "") {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
+                                           f32ApproxFunc, f16Func,
+                                           /*i32Func=*/"", benefit);
+}
+
+template <typename OpTy>
+static void populateIntOpPatterns(const LLVMTypeConverter &converter,
+                                  RewritePatternSet &patterns,
+                                  PatternBenefit benefit, StringRef i32Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, "", "", "", "", i32Func,
+                                           benefit);
+}
+
+template <typename OpTy>
+static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
+                                       RewritePatternSet &patterns,
+                                       PatternBenefit benefit,
+                                       StringRef f32Func, StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter, benefit);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func, "", "",
+                                           /*i32Func=*/"", benefit);
+}
+
+// Custom pattern for sincos since it returns two values
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = adaptor.getOperand();
+    Type inputType = input.getType();
+    auto convertedInput = maybeExt(input, rewriter);
+    auto computeType = convertedInput.getType();
+
+    StringRef sincosFunc;
+    if (isa<Float32Type>(computeType)) {
+      const arith::FastMathFlags flag = op.getFastmath();
+      const bool useApprox =
+          mlir::arith::bitEnumContainsAny(flag, arith::FastMathFlags::afn);
+      sincosFunc = useApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+    } else if (isa<Float64Type>(computeType)) {
+      sincosFunc = "__nv_sincos";
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported operand type for sincos");
+    }
+
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+    Value sinPtr, cosPtr;
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      auto *scope =
+          op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+      assert(scope && "Expected op to be inside automatic allocation scope");
+      rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
+      auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(1));
+      sinPtr =
+          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
+      cosPtr =
+          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
+    }
+
+    createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
+                     op);
+
+    auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
+    auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
+
+    rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
+                            maybeTrunc(cosResult, inputType, rewriter)});
+    return success();
+  }
+
+private:
+  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+    if (isa<Float16Type, BFloat16Type>(operand.getType()))
+      return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+                                   Float32Type::get(rewriter.getContext()),
+                                   operand);
+    return operand;
+  }
+
+  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+    if (operand.getType() != type)
+      return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
+    return operand;
+  }
+
+  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+                        StringRef funcName, Value input, Value sinPtr,
+                        Value cosPtr, Operation *op) const {
+    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+    auto ptrType = sinPtr.getType();
+
+    SmallVector<Type> operandTypes = {input.getType(), ptrType, ptrType};
+    auto funcType = LLVM::LLVMFunctionType::get(voidType, operandTypes);
+
+    auto funcAttr = StringAttr::get(op->getContext(), funcName);
+    auto funcOp =
+        SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(op, funcAttr);
+
+    if (!funcOp) {
+      auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+      assert(parentFunc && "expected there to be a parent function");
+      OpBuilder b(parentFunc);
+
+      auto globalloc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+      funcOp = LLVM::LLVMFuncOp::create(b, globalloc, funcName, funcType);
+    }
+
+    SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
+    LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
+  }
+};
+
+void mlir::populateMathToNVVMConversionPatterns(
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit) {
+  populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
+                                    "__nv_fmod");
+  populateOpPatterns<arith::MaxNumFOp>(converter, patterns, benefit,
+                                       "__nv_fmaxf", "__nv_fmax");
+  populateOpPatterns<arith::MinNumFOp>(converter, patterns, benefit,
+                                       "__nv_fminf", "__nv_fmin");
+
+  populateIntOpPatterns<math::AbsIOp>(converter, patterns, benefit, "__nv_abs");
+  populateOpPatterns<math::AbsFOp>(converter, patterns, benefit, "__nv_fabsf",
+                                   "__nv_fabs");
+  populateOpPatterns<math::AcosOp>(converter, patterns, benefit, "__nv_acosf",
+                                   "__nv_acos");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, benefit, "__nv_acoshf",
+                                    "__nv_acosh");
+  populateOpPatterns<math::AsinOp>(converter, patterns, benefit, "__nv_asinf",
+                                   "__nv_asin");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, benefit, "__nv_asinhf",
+                                    "__nv_asinh");
+  populateOpPatterns<math::AtanOp>(converter, patterns, benefit, "__nv_atanf",
+                                   "__nv_atan");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, benefit, "__nv_atan2f",
+                                    "__nv_atan2");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, benefit, "__nv_atanhf",
+                                    "__nv_atanh");
+  populateOpPatterns<math::CbrtOp>(converter, patterns, benefit, "__nv_cbrtf",
+                                   "__nv_cbrt");
+  populateOpPatterns<math::CeilOp>(converter, patterns, benefit, "__nv_ceilf",
+                                   "__nv_ceil");
+  populateOpPatterns<math::CopySignOp>(converter, patterns, benefit,
+                                       "__nv_copysignf", "__nv_copysign");
+  populateOpPatterns<math::CosOp>(converter, patterns, benefit, "__nv_cosf",
+                                  "__nv_cos", "__nv_fast_cosf");
+  populateOpPatterns<math::CoshOp>(converter, patterns, benefit, "__nv_coshf",
+                                   "__nv_cosh");
+  populateOpPatterns<math::ErfOp>(converter, patterns, benefit, "__nv_erff",
+                                  "__nv_erf");
+  populateOpPatterns<math::ErfcOp>(converter, patterns, benefit, "__nv_erfcf",
+                                   "__nv_erfc");
+  populateOpPatterns<math::ExpOp>(converter, patterns, benefit, "__nv_expf",
+                                  "__nv_exp", "__nv_fast_expf");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, benefit, "__nv_exp2f",
+                                   "__nv_exp2");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, benefit, "__nv_expm1f",
+                                    "__nv_expm1");
+  populateOpPatterns<math::FloorOp>(converter, patterns, benefit, "__nv_floorf",
+                                    "__nv_floor");
+  populateOpPatterns<math::FmaOp>(converter, patterns, benefit, "__nv_fmaf",
+                                  "__nv_fma");
+  // Note: libdevice uses a different name for 32-bit finite checking
+  populateOpPatterns<math::IsFiniteOp>(converter, patterns, benefit,
+                                       "__nv_finitef", "__nv_isfinited");
+  populateOpPatterns<math::IsInfOp>(converter, patterns, benefit, "__nv_isinff",
+                                    "__nv_isinfd");
+  populateOpPatterns<math::IsNaNOp>(converter, patterns, benefit, "__nv_isnanf",
+                                    "__nv_isnand");
+  populateOpPatterns<math::LogOp>(converter, patterns, benefit, "__nv_logf",
+                                  "__nv_log", "__nv_fast_logf");
+  populateOpPatterns<math::Log10Op>(converter, patterns, benefit, "__nv_log10f",
+                                    "__nv_log10", "__nv_fast_log10f");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, benefit, "__nv_log1pf",
+                                    "__nv_log1p");
+  populateOpPatterns<math::Log2Op>(converter, patterns, benefit, "__nv_log2f",
+                                   "__nv_log2", "__nv_fast_log2f");
+  populateOpPatterns<math::PowFOp>(converter, patterns, benefit, "__nv_powf",
+                                   "__nv_pow", "__nv_fast_powf");
+  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, benefit,
+                                            "__nv_powif", "__nv_powi");
+  populateOpPatterns<math::RoundOp>(converter, patterns, benefit, "__nv_roundf",
+                                    "__nv_round");
+  populateOpPatterns<math::RoundEvenOp>(converter, patterns, benefit,
+                                        "__nv_rintf", "__nv_rint");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, benefit, "__nv_rsqrtf",
+                                    "__nv_rsqrt");
+  populateOpPatterns<math::SinOp>(converter, patterns, benefit, "__nv_sinf",
+                                  "__nv_sin", "__nv_fast_sinf");
+  populateOpPatterns<math::SinhOp>(converter, patterns, benefit, "__nv_sinhf",
+                                   "__nv_sinh");
+  populateOpPatterns<math::SqrtOp>(converter, patterns, benefit, "__nv_sqrtf",
+                                   "__nv_sqrt");
+  populateOpPatterns<math::TanOp>(converter, patterns, benefit, "__nv_tanf",
+                                  "__nv_tan", "__nv_fast_tanf");
+  populateOpPatterns<math::TanhOp>(converter, patterns, benefit, "__nv_tanhf",
+                                   "__nv_tanh");
+
+  // Custom pattern for sincos since it returns two values
+  patterns.add<SincosOpLowering>(converter, benefit);
+}
+
+namespace {
+struct ConvertMathToNVVMPass final
+    : impl::ConvertMathToNVVMBase<ConvertMathToNVVMPass> {
+  using impl::ConvertMathToNVVMBase<
+      ConvertMathToNVVMPass>::ConvertMathToNVVMBase;
+
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToNVVMPass::runOnOperation() {
+  auto m = getOperation();
+  MLIRContext *ctx = m.getContext();
+
+  RewritePatternSet patterns(&getContext());
+  LowerToLLVMOptions options(ctx, DataLayout(m));
+  LLVMTypeConverter converter(ctx, options);
+
+  populateMathToNVVMConversionPatterns(converter, patterns, /*benefit=*/1);
+
+  ConversionTarget target(getContext());
+  target
+      .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
+                       LLVM::LLVMDialect, NVVM::NVVMDialect>();
+  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
+                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
+                      LLVM::SqrtOp>();
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}

>From d2d6ae8f8afed94f945204a201bebe82703ddd7e Mon Sep 17 00:00:00 2001
From: jason-van-beusekom <jason.van-beusekom at hpe.com>
Date: Tue, 10 Feb 2026 16:19:10 -0600
Subject: [PATCH 2/4] Restore the populateLibDeviceConversionPatterns name
 based on review feedback

---
 mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h   | 2 +-
 mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +-
 mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp          | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
index e0e2b2c2e08c3..4100479c6e5dd 100644
--- a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
+++ b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
@@ -20,7 +20,7 @@ class Pass;
 
 /// Populate the given list with patterns that convert from Math to NVVM
 /// libdevice calls.
-void populateMathToNVVMConversionPatterns(const LLVMTypeConverter &converter,
+void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
                                           RewritePatternSet &patterns,
                                           PatternBenefit benefit = 1);
 } // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 4d963c1681511..660b24b071b49 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -527,7 +527,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
                           NVVM::NVVMDialect::getClusterDimAttrName())},
       benefit);
 
-  populateMathToNVVMConversionPatterns(converter, patterns, benefit);
+  populateLibDeviceConversionPatterns(converter, patterns, benefit);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
index 5ef3c1fd7f1b4..f9e0b32911912 100644
--- a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
+++ b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
@@ -155,7 +155,7 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
   }
 };
 
-void mlir::populateMathToNVVMConversionPatterns(
+void mlir::populateLibDeviceConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
   populateOpPatterns<arith::RemFOp>(converter, patterns, benefit, "__nv_fmodf",
@@ -264,7 +264,7 @@ void ConvertMathToNVVMPass::runOnOperation() {
   LowerToLLVMOptions options(ctx, DataLayout(m));
   LLVMTypeConverter converter(ctx, options);
 
-  populateMathToNVVMConversionPatterns(converter, patterns, /*benefit=*/1);
+  populateLibDeviceConversionPatterns(converter, patterns, /*benefit=*/1);
 
   ConversionTarget target(getContext());
   target

>From 957d063e7a78fbb919687033165b6f187a6a110c Mon Sep 17 00:00:00 2001
From: jason-van-beusekom <jason.van-beusekom at hpe.com>
Date: Tue, 10 Feb 2026 16:22:40 -0600
Subject: [PATCH 3/4] Fix parameter alignment in MathToNVVM.h after the rename

---
 mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
index 4100479c6e5dd..e354d7dc4cbf1 100644
--- a/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
+++ b/mlir/include/mlir/Conversion/MathToNVVM/MathToNVVM.h
@@ -21,8 +21,8 @@ class Pass;
 /// Populate the given list with patterns that convert from Math to NVVM
 /// libdevice calls.
 void populateLibDeviceConversionPatterns(const LLVMTypeConverter &converter,
-                                          RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+                                         RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTONVVM_MATHTONVVM_H_

>From e8c6cd91efe6bf8e0347e6dea01b3c900b71745b Mon Sep 17 00:00:00 2001
From: Jason Van Beusekom <jason.van-beusekom at hpe.com>
Date: Fri, 13 Feb 2026 11:58:25 -0600
Subject: [PATCH 4/4] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Valentin Clement (バレンタイン クレメン) <clementval at gmail.com>
---
 mlir/include/mlir/Conversion/Passes.td        | 4 ++--
 mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index fd9cbddbd7ab0..37a8bf2f45d72 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -832,9 +832,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
 //===----------------------------------------------------------------------===//
 
 def ConvertMathToNVVM : Pass<"convert-math-to-nvvm", "ModuleOp"> {
-  let summary = "Convert Math dialect to NVVM libdevice calls";
+  let summary = "Convert Math dialect to CUDA libdevice calls";
   let description = [{
-    This pass converts supported Math ops to NVVM libdevice calls.
+    This pass converts supported Math ops to CUDA libdevice calls.
   }];
   let dependentDialects = ["arith::ArithDialect", "func::FuncDialect",
                            "NVVM::NVVMDialect", "vector::VectorDialect",
diff --git a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
index f9e0b32911912..80619f204df70 100644
--- a/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
+++ b/mlir/lib/Conversion/MathToNVVM/MathToNVVM.cpp
@@ -1,4 +1,4 @@
-//===-- MathToNVVM.cpp - conversion from Math to NVVM libdevice calls ----===//
+//===-- MathToNVVM.cpp - conversion from Math to CUDA libdevice calls ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.



More information about the Mlir-commits mailing list