[Mlir-commits] [mlir] [NVVM][MLIR] Refactor conversion of Math / Arith Operations into separate Passes (PR #180058)
Jason Van Beusekom
llvmlistbot at llvm.org
Tue Feb 10 10:27:29 PST 2026
================
@@ -0,0 +1,279 @@
+//===-- MathToNVVM.cpp - conversion from Math to NVVM libdevice calls ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToNVVM/MathToNVVM.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/Pass/Pass.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTONVVM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-nvvm"
+
+/// Registers the two lowering patterns used for the floating-point op `OpTy`.
+///
+/// A `ScalarizeVectorOpLowering<OpTy>` pattern is added first, followed by an
+/// `OpToFuncCallLowering<OpTy>` pattern that receives the callee names given
+/// here (presumably libdevice symbols; confirm at the call sites).
+/// `f32ApproxFunc` and `f16Func` default to the empty string, and the i32
+/// callee name is always passed as the empty string.
+template <typename OpTy>
+static void populateOpPatterns(const LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns,
+                               PatternBenefit benefit, StringRef f32Func,
+                               StringRef f64Func, StringRef f32ApproxFunc = "",
+                               StringRef f16Func = "") {
+  using VectorScalarizer = ScalarizeVectorOpLowering<OpTy>;
+  using FuncCallLowering = OpToFuncCallLowering<OpTy>;
+
+  patterns.add<VectorScalarizer>(converter, benefit);
+  patterns.add<FuncCallLowering>(converter, f32Func, f64Func, f32ApproxFunc,
+                                 f16Func, /*i32Func=*/"", benefit);
+}
+
+/// Registers the two lowering patterns used for the integer op `OpTy`.
+///
+/// A `ScalarizeVectorOpLowering<OpTy>` pattern is added first, followed by an
+/// `OpToFuncCallLowering<OpTy>` pattern. Only the i32 callee name is supplied;
+/// every floating-point callee name is passed as the empty string.
+template <typename OpTy>
+static void populateIntOpPatterns(const LLVMTypeConverter &converter,
+                                  RewritePatternSet &patterns,
+                                  PatternBenefit benefit, StringRef i32Func) {
+  using VectorScalarizer = ScalarizeVectorOpLowering<OpTy>;
+  using FuncCallLowering = OpToFuncCallLowering<OpTy>;
+
+  patterns.add<VectorScalarizer>(converter, benefit);
+  patterns.add<FuncCallLowering>(converter, /*f32Func=*/"", /*f64Func=*/"",
+                                 /*f32ApproxFunc=*/"", /*f16Func=*/"", i32Func,
+                                 benefit);
+}
+
+/// Registers the two lowering patterns used for the op `OpTy`, given only an
+/// f32 and an f64 callee name.
+///
+/// A `ScalarizeVectorOpLowering<OpTy>` pattern is added first, followed by an
+/// `OpToFuncCallLowering<OpTy>` pattern. The approximate-f32, f16 and i32
+/// callee names are all passed as the empty string.
+template <typename OpTy>
+static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
+                                       RewritePatternSet &patterns,
+                                       PatternBenefit benefit,
+                                       StringRef f32Func, StringRef f64Func) {
+  using VectorScalarizer = ScalarizeVectorOpLowering<OpTy>;
+  using FuncCallLowering = OpToFuncCallLowering<OpTy>;
+
+  patterns.add<VectorScalarizer>(converter, benefit);
+  patterns.add<FuncCallLowering>(converter, f32Func, f64Func,
+                                 /*f32ApproxFunc=*/"", /*f16Func=*/"",
+                                 /*i32Func=*/"", benefit);
+}
+
+/// Lowers `math.sincos` to a call of `__nv_sincosf`, `__nv_fast_sincosf` or
+/// `__nv_sincos`.
+///
+/// A dedicated pattern is needed because the op yields two results: the callee
+/// is declared as `void(T, ptr, ptr)` and writes its outputs through the two
+/// pointers, which are loaded back after the call.
+struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
+  using ConvertOpToLLVMPattern<math::SincosOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(math::SincosOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value operand = adaptor.getOperand();
+    Type resultType = operand.getType();
+
+    // f16/bf16 operands are widened to f32; everything else is left untouched.
+    Value widened = maybeExt(operand, rewriter);
+    Type computeType = widened.getType();
+
+    // Choose the callee from the type the computation is carried out in.
+    StringRef calleeName;
+    if (isa<Float64Type>(computeType)) {
+      calleeName = "__nv_sincos";
+    } else if (isa<Float32Type>(computeType)) {
+      // The `afn` fast-math flag selects the approximate variant.
+      const bool wantsApprox = mlir::arith::bitEnumContainsAny(
+          op.getFastmath(), arith::FastMathFlags::afn);
+      calleeName = wantsApprox ? "__nv_fast_sincosf" : "__nv_sincosf";
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported operand type for sincos");
+    }
+
+    auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
+
+    // Create the two output slots at the start of the first block of the
+    // nearest enclosing automatic allocation scope; the guard restores the
+    // previous insertion point when this block is left.
+    Value sinSlot, cosSlot;
+    {
+      OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
+      Operation *allocScope =
+          op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
+      assert(allocScope &&
+             "Expected op to be inside automatic allocation scope");
+      Block &entryBlock = allocScope->getRegion(0).front();
+      rewriter.setInsertionPointToStart(&entryBlock);
+
+      auto elementCount = LLVM::ConstantOp::create(
+          rewriter, loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+      sinSlot = LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType,
+                                       elementCount, 0);
+      cosSlot = LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType,
+                                       elementCount, 0);
+    }
+
+    createSincosCall(rewriter, loc, calleeName, widened, sinSlot, cosSlot, op);
+
+    // Read both outputs back and narrow them to the operand type if the
+    // operand was widened above.
+    auto sinLoaded = LLVM::LoadOp::create(rewriter, loc, computeType, sinSlot);
+    auto cosLoaded = LLVM::LoadOp::create(rewriter, loc, computeType, cosSlot);
+    Value sinOut = maybeTrunc(sinLoaded, resultType, rewriter);
+    Value cosOut = maybeTrunc(cosLoaded, resultType, rewriter);
+
+    rewriter.replaceOp(op, {sinOut, cosOut});
+    return success();
+  }
+
+private:
+  /// Returns `operand` extended to f32 when it is f16 or bf16, and `operand`
+  /// itself otherwise.
+  Value maybeExt(Value operand, PatternRewriter &rewriter) const {
+    if (!isa<Float16Type, BFloat16Type>(operand.getType()))
+      return operand;
+    return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+                                 Float32Type::get(rewriter.getContext()),
+                                 operand);
+  }
+
+  /// Returns `operand` truncated to `type` when its type differs from `type`,
+  /// and `operand` itself otherwise.
+  Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
+    if (operand.getType() == type)
+      return operand;
+    return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
+  }
+
+  /// Emits `call @funcName(input, sinPtr, cosPtr)` at the rewriter's current
+  /// insertion point. If no `LLVM::LLVMFuncOp` named `funcName` is found by
+  /// the nearest-symbol lookup from `op`, a declaration is created
+  /// immediately before the function that encloses `op`.
+  void createSincosCall(ConversionPatternRewriter &rewriter, Location loc,
+                        StringRef funcName, Value input, Value sinPtr,
+                        Value cosPtr, Operation *op) const {
+    // Callee signature: void(<input type>, ptr, ptr).
+    Type outPtrType = sinPtr.getType();
+    auto voidType = LLVM::LLVMVoidType::get(rewriter.getContext());
+    SmallVector<Type> paramTypes = {input.getType(), outPtrType, outPtrType};
+    auto calleeType = LLVM::LLVMFunctionType::get(voidType, paramTypes);
+
+    auto calleeSymbol = StringAttr::get(op->getContext(), funcName);
+    auto callee = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
+        op, calleeSymbol);
+
+    if (!callee) {
+      auto enclosingFunc = op->getParentOfType<FunctionOpInterface>();
+      assert(enclosingFunc && "expected there to be a parent function");
+      OpBuilder declBuilder(enclosingFunc);
+
+      // The declaration is given the file/line/column part of `loc`, or an
+      // unknown location if `loc` contains none.
+      auto declLoc = loc->findInstanceOfOrUnknown<FileLineColLoc>();
+      callee =
+          LLVM::LLVMFuncOp::create(declBuilder, declLoc, funcName, calleeType);
+    }
+
+    SmallVector<Value> callArgs = {input, sinPtr, cosPtr};
+    LLVM::CallOp::create(rewriter, loc, callee, callArgs);
+  }
+};
+
+void mlir::populateMathToNVVMConversionPatterns(
----------------
Jason-Van-Beusekom wrote:
@clementval commented in https://github.com/llvm/llvm-project/pull/180060
"populateLibDeviceConversionPatterns might fit better since most of these are actual libdevice function calls."
https://github.com/llvm/llvm-project/pull/180058
More information about the Mlir-commits
mailing list