[Mlir-commits] [mlir] [MLIR][ROCDL] Convert `math::fpowi` to ROCDL call (PR #122640)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 12 07:46:15 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: lialan (lialan)

<details>
<summary>Changes</summary>

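* The same-operands-and-result-type `static_assert` in `OpToFuncCallLowering` had to be relaxed to admit `math::FPowIOp` (whose operands are a float base and an integer exponent), so that the existing template patterns can be reused for this conversion.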

---
Full diff: https://github.com/llvm/llvm-project/pull/122640.diff


3 Files Affected:

- (modified) mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h (+3-1) 
- (modified) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+3-1) 
- (modified) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+18) 


``````````diff
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 3b94abd88f9ed2..caa3148dedff57 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/IR/Builders.h"
 
 namespace mlir {
@@ -58,7 +59,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
         "expected single result op");
 
     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
-                                  SourceOp>::value,
+                                  SourceOp>::value ||
+                      std::is_same_v<SourceOp, math::FPowIOp>,
                   "expected op with same operand and result types");
 
     if (!op->template getParentOfType<FunctionOpInterface>()) {
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index c17bfe4f71a98d..627bed011826a2 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -57,7 +57,7 @@ void mlir::populateMathToROCDLConversionPatterns(
   // Handled by mathToLLVM: math::FmaOp
   // Handled by mathToLLVM: math::LogOp (32-bit only)
   // FIXME: math::IPowIOp
-  // FIXME: math::FPowIOp
+  // Handled by mathToLLVM: math::FPowIOp
   // Handled by mathToLLVM: math::RoundEvenOp
   // Handled by mathToLLVM: math::RoundOp
   // Handled by mathToLLVM: math::SqrtOp
@@ -114,6 +114,8 @@ void mlir::populateMathToROCDLConversionPatterns(
                                   "__ocml_tan_f64", "__ocml_tan_f16");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
                                   "__ocml_erf_f64", "__ocml_erf_f16");
+  populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
+                                    "__ocml_pown_f64", "__ocml_pown_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e0ea18d41f66da..e4b2f01d6544ab 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -484,6 +484,24 @@ module @test_module {
 
 // -----
 
+module @test_module {
+  // CHECK: llvm.func @__ocml_pown_f16(f16, i32) -> f16
+  // CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32
+  // CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64
+  // CHECK-LABEL: func @math_fpowi
+  func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) {
+    // CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16
+    %0 = math.fpowi %arg0, %arg3 : f16, i32
+    // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
+    %1 = math.fpowi %arg1, %arg3 : f32, i32
+    // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
+    %2 = math.fpowi %arg2, %arg3 : f64, i32
+    return %0, %1, %2 : f16, f32, f64
+  }
+}
+
+// -----
+
 // Math operation not inside function
 // Ensure it not crash
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/122640


More information about the Mlir-commits mailing list