[flang-commits] [flang] [flang][AMDGPU] Convert math ops to AMD GPU library calls instead of libm calls (PR #99517)

Jan Leyonberg via flang-commits flang-commits at lists.llvm.org
Thu Jul 18 08:42:07 PDT 2024


https://github.com/jsjodin created https://github.com/llvm/llvm-project/pull/99517

This patch invokes a pass when compiling for an AMDGPU target to lower math operations to AMD GPU library calls instead of libm calls.

>From 665a76feb553c855e82be44629b86d521b4cdc2a Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Thu, 18 Jul 2024 11:06:51 -0400
Subject: [PATCH] [flang][AMDGPU] Convert math ops to AMD GPU library calls
 instead of libm calls

This patch invokes a pass when compiling for an AMDGPU target to lower math
operations to AMD GPU library calls instead of libm calls.
---
 flang/lib/Optimizer/CodeGen/CMakeLists.txt |   1 +
 flang/lib/Optimizer/CodeGen/CodeGen.cpp    |  12 +-
 flang/test/Lower/OpenMP/math-amdgpu.f90    | 184 +++++++++++++++++++++
 3 files changed, 196 insertions(+), 1 deletion(-)
 create mode 100644 flang/test/Lower/OpenMP/math-amdgpu.f90

diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 650448eee1099..646621cb01c15 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
+  MLIRMathToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f9ea92a843b23..02992857dde06 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -34,6 +34,7 @@
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -3610,6 +3611,14 @@ class FIRToLLVMLowering
     // as passes here.
     mlir::OpPassManager mathConvertionPM("builtin.module");
 
+    bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
+    // If compiling for an AMD target, some math operations must be lowered to
+    // AMD GPU library calls; the rest can be converted to LLVM intrinsics,
+    // which is handled in the mathToLLVM conversion. The lowering to libm calls
+    // is not needed since all math operations are handled this way.
+    if (isAMDGCN)
+      mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
     // it will be lowered to LLVM intrinsic operation by a later conversion.
@@ -3649,7 +3658,8 @@ class FIRToLLVMLowering
                                                           pattern);
     // Math operations that have not been converted yet must be converted
     // to Libm.
-    mlir::populateMathToLibmConversionPatterns(pattern);
+    if (!isAMDGCN)
+      mlir::populateMathToLibmConversionPatterns(pattern);
     mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
     mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);
 
diff --git a/flang/test/Lower/OpenMP/math-amdgpu.f90 b/flang/test/Lower/OpenMP/math-amdgpu.f90
new file mode 100644
index 0000000000000..b455b42d3ed34
--- /dev/null
+++ b/flang/test/Lower/OpenMP/math-amdgpu.f90
@@ -0,0 +1,184 @@
+!REQUIRES: amdgpu-registered-target
+!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+
+subroutine omp_pow_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_pow_f32(float {{.*}}, float {{.*}})
+  y = x ** x
+end subroutine omp_pow_f32
+
+subroutine omp_pow_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_pow_f64(double {{.*}}, double {{.*}})
+  y = x ** x
+end subroutine omp_pow_f64
+
+subroutine omp_sin_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_sin_f32(float {{.*}})
+  y = sin(x)
+end subroutine omp_sin_f32
+
+subroutine omp_sin_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_sin_f64(double {{.*}})
+  y = sin(x)
+end subroutine omp_sin_f64
+
+subroutine omp_abs_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_fabs_f32(float {{.*}})
+  y = abs(x)
+end subroutine omp_abs_f32
+
+subroutine omp_abs_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_fabs_f64(double {{.*}})
+  y = abs(x)
+end subroutine omp_abs_f64
+
+subroutine omp_atan_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_atan_f32(float {{.*}})
+  y = atan(x)
+end subroutine omp_atan_f32
+
+subroutine omp_atan_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_atan_f64(double {{.*}})
+  y = atan(x)
+end subroutine omp_atan_f64
+
+subroutine omp_atan2_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_atan2_f32(float {{.*}}, float {{.*}})
+  y = atan2(x, x)
+end subroutine omp_atan2_f32
+
+subroutine omp_atan2_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_atan2_f64(double {{.*}}, double {{.*}})
+  y = atan2(x ,x)
+end subroutine omp_atan2_f64
+
+subroutine omp_cos_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_cos_f32(float {{.*}})
+  y = cos(x)
+end subroutine omp_cos_f32
+
+subroutine omp_cos_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_cos_f64(double {{.*}})
+  y = cos(x)
+end subroutine omp_cos_f64
+
+subroutine omp_erf_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_erf_f32(float {{.*}})
+  y = erf(x)
+end subroutine omp_erf_f32
+
+subroutine omp_erf_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_erf_f64(double {{.*}})
+  y = erf(x)
+end subroutine omp_erf_f64
+
+subroutine omp_exp_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_exp_f32(float {{.*}})
+  y = exp(x)
+end subroutine omp_exp_f32
+
+subroutine omp_exp_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_exp_f64(double {{.*}})
+  y = exp(x)
+end subroutine omp_exp_f64
+
+subroutine omp_log_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_log_f32(float {{.*}})
+  y = log(x)
+end subroutine omp_log_f32
+
+subroutine omp_log_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_log_f64(double {{.*}})
+  y = log(x)
+end subroutine omp_log_f64
+
+subroutine omp_log10_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_log10_f32(float {{.*}})
+  y = log10(x)
+end subroutine omp_log10_f32
+
+subroutine omp_log10_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_log10_f64(double {{.*}})
+  y = log10(x)
+end subroutine omp_log10_f64
+
+subroutine omp_sqrt_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_sqrt_f32(float {{.*}})
+  y = sqrt(x)
+end subroutine omp_sqrt_f32
+
+subroutine omp_sqrt_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_sqrt_f64(double {{.*}})
+  y = sqrt(x)
+end subroutine omp_sqrt_f64
+
+subroutine omp_tan_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_tan_f32(float {{.*}})
+  y = tan(x)
+end subroutine omp_tan_f32
+
+subroutine omp_tan_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_tan_f64(double {{.*}})
+  y = tan(x)
+end subroutine omp_tan_f64
+
+subroutine omp_tanh_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_tanh_f32(float {{.*}})
+  y = tanh(x)
+end subroutine omp_tanh_f32
+
+subroutine omp_tanh_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_tanh_f64(double {{.*}})
+  y = tanh(x)
+end subroutine omp_tanh_f64



More information about the flang-commits mailing list