[Mlir-commits] [mlir] 9f114af - [MLIR][ROCDL] Convert `math::fpowi` to ROCDL call (#122640)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 13 20:31:29 PST 2025


Author: lialan
Date: 2025-01-13T20:31:25-08:00
New Revision: 9f114afe092483983a82a73c82704f11bb28bf8c

URL: https://github.com/llvm/llvm-project/commit/9f114afe092483983a82a73c82704f11bb28bf8c
DIFF: https://github.com/llvm/llvm-project/commit/9f114afe092483983a82a73c82704f11bb28bf8c.diff

LOG: [MLIR][ROCDL] Convert `math::fpowi` to ROCDL call (#122640)

* Relax the compile-time `SameOperandsAndResultType` static assert into runtime
asserts (for ops lacking that trait) so the existing templated patterns can be
reused for this conversion.

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
    mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
    mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 3b94abd88f9ed2..46fd182346b3b7 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -57,9 +57,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
-                                  SourceOp>::value,
-                  "expected op with same operand and result types");
+    if constexpr (!std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                   SourceOp>::value) {
+      assert(op->getNumOperands() > 0 &&
+             "expected op to take at least one operand");
+      assert(op->getResultTypes().front() == op->getOperand(0).getType() &&
+             "expected op with same operand and result types");
+    }
 
     if (!op->template getParentOfType<FunctionOpInterface>()) {
       return rewriter.notifyMatchFailure(

diff  --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index c17bfe4f71a98d..838eef30a938fe 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -57,7 +57,6 @@ void mlir::populateMathToROCDLConversionPatterns(
   // Handled by mathToLLVM: math::FmaOp
   // Handled by mathToLLVM: math::LogOp (32-bit only)
   // FIXME: math::IPowIOp
-  // FIXME: math::FPowIOp
   // Handled by mathToLLVM: math::RoundEvenOp
   // Handled by mathToLLVM: math::RoundOp
   // Handled by mathToLLVM: math::SqrtOp
@@ -114,6 +113,8 @@ void mlir::populateMathToROCDLConversionPatterns(
                                   "__ocml_tan_f64", "__ocml_tan_f16");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
                                   "__ocml_erf_f64", "__ocml_erf_f16");
+  populateOpPatterns<math::FPowIOp>(converter, patterns, "__ocml_pown_f32",
+                                    "__ocml_pown_f64", "__ocml_pown_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",

diff  --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index e0ea18d41f66da..e4b2f01d6544ab 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -484,6 +484,24 @@ module @test_module {
 
 // -----
 
+module @test_module {
+  // CHECK: llvm.func @__ocml_pown_f16(f16, i32) -> f16
+  // CHECK: llvm.func @__ocml_pown_f32(f32, i32) -> f32
+  // CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64
+  // CHECK-LABEL: func @math_fpowi
+  func.func @math_fpowi(%arg0: f16, %arg1: f32, %arg2: f64, %arg3: i32) -> (f16, f32, f64) {
+    // CHECK: llvm.call @__ocml_pown_f16(%{{.*}}) : (f16, i32) -> f16
+    %0 = math.fpowi %arg0, %arg3 : f16, i32
+    // CHECK: llvm.call @__ocml_pown_f32(%{{.*}}) : (f32, i32) -> f32
+    %1 = math.fpowi %arg1, %arg3 : f32, i32
+    // CHECK: llvm.call @__ocml_pown_f64(%{{.*}}) : (f64, i32) -> f64
+    %2 = math.fpowi %arg2, %arg3 : f64, i32
+    return %0, %1, %2 : f16, f32, f64
+  }
+}
+
+// -----
+
 // Math operation not inside function
 // Ensure it not crash
 


        


More information about the Mlir-commits mailing list