[Mlir-commits] [mlir] [MLIR][Math][GPU] Add lowering of absi and fpowi to libdevice (PR #123422)

William Moses llvmlistbot at llvm.org
Sun Jan 19 12:19:13 PST 2025


https://github.com/wsmoses updated https://github.com/llvm/llvm-project/pull/123422

>From 7c849d7695247c9222cb8dd73b66aa35c328e650 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh at wsmoses.com>
Date: Fri, 17 Jan 2025 17:29:34 -0600
Subject: [PATCH 1/2] [MLIR][Math][GPU] Add lowering of absi and fpowi to
 libdevice

---
 .../GPUCommon/OpToFuncCallLowering.h          | 188 +++++++++++++-----
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  19 ++
 .../Conversion/GPUToNVVM/gpu-to-nvvm.mlir     |  24 +++
 3 files changed, 181 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 46fd182346b3b7..bbfcdaf91205ca 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -16,37 +16,11 @@
 
 namespace mlir {
 
-/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
-/// `f32ApproxFunc` or `f16Func` depending on the element type and the
-/// fastMathFlag of that Op. The function declaration is added in case it was
-/// not added before.
-///
-/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
-/// value is first casted to f32, the function called and then the result casted
-/// back.
-///
-/// Example with NVVM:
-///   %exp_f32 = math.exp %arg_f32 : f32
-///
-/// will be transformed into
-///   llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
-///
-/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
-/// to the approximate calculation function.
-///
-/// Also example with NVVM:
-///   %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
-///
-/// will be transformed into
-///   llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
-template <typename SourceOp>
-struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
+template <typename SourceOp, typename DerivedTy>
+struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
 public:
-  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
-                                StringRef f32Func, StringRef f64Func,
-                                StringRef f32ApproxFunc, StringRef f16Func)
-      : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
-        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
+  explicit OpToFuncCallLoweringBase(const LLVMTypeConverter &lowering)
+      : ConvertOpToLLVMPattern<SourceOp>(lowering) {}
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -72,13 +46,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 
     SmallVector<Value, 1> castedOperands;
     for (Value operand : adaptor.getOperands())
-      castedOperands.push_back(maybeCast(operand, rewriter));
+      castedOperands.push_back(
+          ((const DerivedTy *)this)->maybeCast(operand, rewriter));
 
     Type resultType = castedOperands.front().getType();
     Type funcType = getFunctionType(resultType, castedOperands);
     StringRef funcName =
-        getFunctionName(cast<LLVM::LLVMFunctionType>(funcType).getReturnType(),
-                        op.getFastmath());
+        ((const DerivedTy *)this)
+            ->getFunctionName(
+                cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
     if (funcName.empty())
       return failure();
 
@@ -99,6 +75,61 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   }
 
 private:
+  Type getFunctionType(Type resultType, ValueRange operands) const {
+    SmallVector<Type> operandTypes(operands.getTypes());
+    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
+  }
+
+  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
+                                     Operation *op) const {
+    using LLVM::LLVMFuncOp;
+
+    auto funcAttr = StringAttr::get(op->getContext(), funcName);
+    Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
+    if (funcOp)
+      return cast<LLVMFuncOp>(*funcOp);
+
+    mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
+    return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
+  }
+};
+
+/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func` or
+/// `f32ApproxFunc` or `f16Func` depending on the element type and the
+/// fastMathFlag of that Op. The function declaration is added in case it was
+/// not added before.
+///
+/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
+/// value is first casted to f32, the function called and then the result casted
+/// back.
+///
+/// Example with NVVM:
+///   %exp_f32 = math.exp %arg_f32 : f32
+///
+/// will be transformed into
+///   llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
+///
+/// If the fastMathFlag attribute of SourceOp is `afn` or `fast`, this Op lowers
+/// to the approximate calculation function.
+///
+/// Also example with NVVM:
+///   %exp_f32 = math.exp %arg_f32 fastmath<afn> : f32
+///
+/// will be transformed into
+///   llvm.call @__nv_fast_expf(%arg_f32) : (f32) -> f32
+template <typename SourceOp>
+struct OpToFuncCallLowering
+    : public OpToFuncCallLoweringBase<SourceOp,
+                                      OpToFuncCallLowering<SourceOp>> {
+public:
+  explicit OpToFuncCallLowering(const LLVMTypeConverter &lowering,
+                                StringRef f32Func, StringRef f64Func,
+                                StringRef f32ApproxFunc, StringRef f16Func)
+      : OpToFuncCallLoweringBase<SourceOp, OpToFuncCallLowering<SourceOp>>(
+            lowering),
+        f32Func(f32Func), f64Func(f64Func), f32ApproxFunc(f32ApproxFunc),
+        f16Func(f16Func) {}
+
   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
     Type type = operand.getType();
     if (!isa<Float16Type, BFloat16Type>(type))
@@ -112,12 +143,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
         operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
   }
 
-  Type getFunctionType(Type resultType, ValueRange operands) const {
-    SmallVector<Type> operandTypes(operands.getTypes());
-    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
-  }
-
-  StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+  StringRef getFunctionName(Type type, SourceOp op) const {
+    arith::FastMathFlags flag = op.getFastmath();
     if (isa<Float16Type>(type))
       return f16Func;
     if (isa<Float32Type>(type)) {
@@ -132,23 +159,84 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     return "";
   }
 
-  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
-                                     Operation *op) const {
-    using LLVM::LLVMFuncOp;
+  const std::string f32Func;
+  const std::string f64Func;
+  const std::string f32ApproxFunc;
+  const std::string f16Func;
+};
 
-    auto funcAttr = StringAttr::get(op->getContext(), funcName);
-    Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
-    if (funcOp)
-      return cast<LLVMFuncOp>(*funcOp);
+/// Rewriting that replaces SourceOp with a CallOp to `i32Func`.
+/// The function declaration is added in case it was not added before.
+/// This assumes that all types are integral.
+///
+/// Example with NVVM:
+///   %abs_i32 = math.iabs %arg_i32 : i32
+///
+/// will be transformed into
+///   llvm.call @__nv_abs(%arg_i32) : (i32) -> i32
+///
+template <typename SourceOp>
+struct IntOpToFuncCallLowering
+    : public OpToFuncCallLoweringBase<SourceOp,
+                                      IntOpToFuncCallLowering<SourceOp>> {
+public:
+  explicit IntOpToFuncCallLowering(const LLVMTypeConverter &lowering,
+                                   StringRef i32Func)
+      : OpToFuncCallLoweringBase<SourceOp, IntOpToFuncCallLowering<SourceOp>>(
+            lowering),
+        i32Func(i32Func) {}
 
-    mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
-    return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    return operand;
+  }
+
+  StringRef getFunctionName(Type type, SourceOp op) const {
+    IntegerType itype = dyn_cast<IntegerType>(type);
+    if (!itype || itype.getWidth() != 32)
+      return "";
+    return i32Func;
+  }
+
+  const std::string i32Func;
+};
+
+/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`,
+/// depending on the type of the result. This assumes that the first argument
+/// is a floating-point type and the second argument is an integer type.
+///
+/// Example with NVVM:
+///   %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
+///
+/// will be transformed into
+///   llvm.call @__nv_powf(%arg_f32, %arg_i32) : (f32, i32) -> f32
+///
+template <typename SourceOp>
+struct FloatIntOpToFuncCallLowering
+    : public OpToFuncCallLoweringBase<SourceOp,
+                                      FloatIntOpToFuncCallLowering<SourceOp>> {
+public:
+  explicit FloatIntOpToFuncCallLowering(const LLVMTypeConverter &lowering,
+                                        StringRef f32Func, StringRef f64Func)
+      : OpToFuncCallLoweringBase<SourceOp,
+                                 FloatIntOpToFuncCallLowering<SourceOp>>(
+            lowering),
+        f32Func(f32Func), f64Func(f64Func) {}
+
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    return operand;
+  }
+
+  StringRef getFunctionName(Type type, SourceOp op) const {
+    if (isa<Float32Type>(type)) {
+      return f32Func;
+    }
+    if (isa<Float64Type>(type))
+      return f64Func;
+    return "";
   }
 
   const std::string f32Func;
   const std::string f64Func;
-  const std::string f32ApproxFunc;
-  const std::string f16Func;
 };
 
 } // namespace mlir
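
As a side note for readers less familiar with the structure used above: OpToFuncCallLoweringBase drives matchAndRewrite and dispatches to the derived class's maybeCast/getFunctionName hooks through a cast to the derived type (CRTP), with no virtual calls. A minimal, self-contained C++ sketch of that mechanism follows; the names are illustrative only and it is not MLIR code.

  #include <iostream>
  #include <string>

  // Minimal CRTP sketch: the base drives the algorithm and calls hooks that
  // each derived class customizes, mirroring how OpToFuncCallLoweringBase
  // defers to its derived patterns.
  template <typename DerivedTy>
  struct LoweringBase {
    void lower(double value) const {
      // Static dispatch to the derived class; no virtual functions involved.
      double casted = static_cast<const DerivedTy *>(this)->maybeCast(value);
      std::string fn = static_cast<const DerivedTy *>(this)->functionName();
      std::cout << "call " << fn << "(" << casted << ")\n";
    }
  };

  struct FloatLowering : LoweringBase<FloatLowering> {
    double maybeCast(double v) const { return v; }            // no cast needed
    std::string functionName() const { return "__nv_expf"; }  // illustrative
  };

  int main() { FloatLowering().lower(1.0); } // prints: call __nv_expf(1)
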
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 2768929f460e2e..1971de30898fb7 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -446,6 +446,22 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
                                            f32ApproxFunc, f16Func);
 }
 
+template <typename OpTy>
+static void populateIntOpPatterns(const LLVMTypeConverter &converter,
+                                  RewritePatternSet &patterns,
+                                  StringRef i32Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<IntOpToFuncCallLowering<OpTy>>(converter, i32Func);
+}
+
+template <typename OpTy>
+static void populateFloatIntOpPatterns(const LLVMTypeConverter &converter,
+                                       RewritePatternSet &patterns,
+                                       StringRef f32Func, StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<FloatIntOpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<GPUSubgroupReduceOpLowering>(converter);
@@ -509,6 +525,7 @@ void mlir::populateGpuToNVVMConversionPatterns(
 
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__nv_fmodf",
                                     "__nv_fmod");
+  populateIntOpPatterns<math::AbsIOp>(converter, patterns, "__nv_abs");
   populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
                                    "__nv_fabs");
   populateOpPatterns<math::AcosOp>(converter, patterns, "__nv_acosf",
@@ -555,6 +572,8 @@ void mlir::populateGpuToNVVMConversionPatterns(
                                    "__nv_log2", "__nv_fast_log2f");
   populateOpPatterns<math::PowFOp>(converter, patterns, "__nv_powf", "__nv_pow",
                                    "__nv_fast_powf");
+  populateFloatIntOpPatterns<math::FPowIOp>(converter, patterns, "__nv_powif",
+                                            "__nv_powi");
   populateOpPatterns<math::RoundOp>(converter, patterns, "__nv_roundf",
                                     "__nv_round");
   populateOpPatterns<math::RoundEvenOp>(converter, patterns, "__nv_rintf",
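
For reference, the new pattern classes can also be registered directly, without the static helpers above. Below is a minimal sketch of how another in-tree GPU conversion might do so; the `__mylib_*` symbols are placeholders rather than real device-library functions, and the include path mirrors how LowerGpuOpsToNVVMOps.cpp reaches the header.

  // Sketch only: hypothetical registration against a made-up device library.
  #include "../GPUCommon/OpToFuncCallLowering.h"
  #include "mlir/Dialect/Math/IR/Math.h"

  using namespace mlir;

  static void populateMyLibMathPatterns(const LLVMTypeConverter &converter,
                                        RewritePatternSet &patterns) {
    // i32 -> i32 device-library call for math.absi.
    patterns.add<IntOpToFuncCallLowering<math::AbsIOp>>(converter,
                                                        "__mylib_abs");
    // (f32, i32) -> f32 and (f64, i32) -> f64 variants for math.fpowi.
    patterns.add<FloatIntOpToFuncCallLowering<math::FPowIOp>>(
        converter, "__mylib_powif", "__mylib_powi");
  }
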
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index f52dd6c0d0ce30..94c0f9e34c29ce 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1033,3 +1033,27 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+gpu.module @test_module_52 {
+  // CHECK: llvm.func @__nv_abs(i32) -> i32
+  // CHECK-LABEL: func @gpu_abs
+  func.func @gpu_abs(%arg_i32 : i32) -> (i32) {
+    %result32 = math.absi %arg_i32 : i32
+    // CHECK: llvm.call @__nv_abs(%{{.*}}) : (i32) -> i32
+    func.return %result32 : i32
+  }
+}
+
+gpu.module @test_module_53 {
+  // CHECK: llvm.func @__nv_powif(f32, i32) -> f32
+  // CHECK: llvm.func @__nv_powi(f64, i32) -> f64
+  // CHECK-LABEL: func @gpu_powi
+  func.func @gpu_powi(%arg_f32 : f32, %arg_f64 : f64, %arg_i32 : i32) -> (f32, f64) {
+    %result32 = math.fpowi %arg_f32, %arg_i32 : f32, i32
+    // CHECK: llvm.call @__nv_powif(%{{.*}}, %{{.*}}) : (f32, i32) -> f32
+    %result64 = math.fpowi %arg_f64, %arg_i32 : f64, i32
+    // CHECK: llvm.call @__nv_powi(%{{.*}}, %{{.*}}) : (f64, i32) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}

>From d6b89f00447add33943da404fa198d80103c8c47 Mon Sep 17 00:00:00 2001
From: William Moses <gh at wsmoses.com>
Date: Sun, 19 Jan 2025 14:19:06 -0600
Subject: [PATCH 2/2] Apply suggestions from code review

Co-authored-by: Oleksandr "Alex" Zinenko <ftynse at gmail.com>
---
 mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index bbfcdaf91205ca..0c1755d593339c 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -47,12 +47,12 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
     SmallVector<Value, 1> castedOperands;
     for (Value operand : adaptor.getOperands())
       castedOperands.push_back(
-          ((const DerivedTy *)this)->maybeCast(operand, rewriter));
+          static_cast<const DerivedTy *>(this)->maybeCast(operand, rewriter));
 
     Type resultType = castedOperands.front().getType();
     Type funcType = getFunctionType(resultType, castedOperands);
     StringRef funcName =
-        ((const DerivedTy *)this)
+        static_cast<const DerivedTy *>(this)
             ->getFunctionName(
                 cast<LLVM::LLVMFunctionType>(funcType).getReturnType(), op);
     if (funcName.empty())
@@ -85,11 +85,13 @@ struct OpToFuncCallLoweringBase : public ConvertOpToLLVMPattern<SourceOp> {
     using LLVM::LLVMFuncOp;
 
     auto funcAttr = StringAttr::get(op->getContext(), funcName);
-    Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
+    auto funcOp = SymbolTable::lookupNearestSymbolFrom<LLVMFuncOp>(op, funcAttr);
     if (funcOp)
-      return cast<LLVMFuncOp>(*funcOp);
+      return funcOp;
 
-    mlir::OpBuilder b(op->getParentOfType<FunctionOpInterface>());
+    auto parentFunc = op->getParentOfType<FunctionOpInterface>();
+    assert(parentFunc && "expected there to be a parent function");
+    OpBuilder b(parentFunc);
     return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
   }
 };


