[Mlir-commits] [mlir] [MLIR][ROCDL] Convert `math::fpowi` to ROCDL call (PR #122640)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 12 22:14:46 PST 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/122640

>From 9c9513713075f810590eb87a813bcc2b507b3515 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Sun, 12 Jan 2025 15:26:25 +0000
Subject: [PATCH 1/3] [MLIR][ROCDL] Convert `math::fpowi` to ROCDL call

---
 .../GPUCommon/OpToFuncCallLowering.h           |  4 +++-
 .../lib/Conversion/MathToROCDL/MathToROCDL.cpp |  4 +++-
 .../Conversion/MathToROCDL/math-to-rocdl.mlir  | 18 ++++++++++++++++++
 3 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 3b94abd88f9ed2..caa3148dedff57 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/IR/Builders.h"
 
 namespace mlir {
@@ -58,7 +59,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
         "expected single result op");
 
     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
-                                  SourceOp>::value,
+                                  SourceOp>::value ||
+                      std::is_same_v<SourceOp, math::FPowIOp>,
                   "expected op with same operand and result types");
 
     if (!op->template getParentOfType<FunctionOpInterface>()) {
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index c17bfe4f71a98d..627bed011826a2 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -57,7 +57,7 @@ void mlir::populateMathToROCDLConversionPatterns(
   // Handled by mathToLLVM: math::FmaOp
   // Handled by mathToLLVM: math::LogOp (32-bit only)
   // FIXME: math::IPowIOp
-  // FIXME: math::FPowIOp
+  // Handled by mathToLLVM: math::FPowIOp
   // Handled by mathToLLVM: math::RoundEvenOp
   // Handled by mathToLLVM: math::RoundOp
   // Handled by mathToLLVM: math::SqrtOp
@@ -114,6 +114,8 @@ void mlir::populateMathToROCDLConversionPatterns(
                                   "__ocml_tan_f64", "__ocml_tan_f16");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
                                   "__ocml_erf_f64", "__ocml_erf_f16");
+  populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
+                                    "__ocml_pown_f64", "__ocml_pown_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e0ea18d41f66da..e4b2f01d6544ab 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -484,6 +484,24 @@ module @test_module {
 
 // -----
 
+module @test_module {
+  // CHECK: llvm.func @__ocml_pown_f16(f16, i32) -> f16
+  // CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32
+  // CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64
+  // CHECK-LABEL: func @math_fpowi
+  func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) {
+    // CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16
+    %0 = math.fpowi %arg0, %arg3 : f16, i32
+    // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
+    %1 = math.fpowi %arg1, %arg3 : f32, i32
+    // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
+    %2 = math.fpowi %arg2, %arg3 : f64, i32
+    return %0, %1, %2 : f16, f32, f64
+  }
+}
+
+// -----
+
 // Math operation not inside function
 // Ensure it not crash
 

>From aab089e234e0ce525826828234cc9bab36638544 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Mon, 13 Jan 2025 04:18:04 +0000
Subject: [PATCH 2/3] Replace the static_assert with a runtime match-failure check that the result type equals the first operand type.

---
 mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index caa3148dedff57..9d9d38cc066aa2 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -58,10 +58,9 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
-                                  SourceOp>::value ||
-                      std::is_same_v<SourceOp, math::FPowIOp>,
-                  "expected op with same operand and result types");
+    if (op->getResultTypes().front() != op->getOperand(0).getType())
+      return rewriter.notifyMatchFailure(
+          op, "expected op with same operand and result types");
 
     if (!op->template getParentOfType<FunctionOpInterface>()) {
       return rewriter.notifyMatchFailure(

>From 8c180db936929005215e5faa255bfbac81503291 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Mon, 13 Jan 2025 06:14:28 +0000
Subject: [PATCH 3/3] Remove the misleading "Handled by mathToLLVM: math::FPowIOp" comment

---
 mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 627bed011826a2..838eef30a938fe 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -57,7 +57,6 @@ void mlir::populateMathToROCDLConversionPatterns(
   // Handled by mathToLLVM: math::FmaOp
   // Handled by mathToLLVM: math::LogOp (32-bit only)
   // FIXME: math::IPowIOp
-  // Handled by mathToLLVM: math::FPowIOp
   // Handled by mathToLLVM: math::RoundEvenOp
   // Handled by mathToLLVM: math::RoundOp
   // Handled by mathToLLVM: math::SqrtOp



More information about the Mlir-commits mailing list