[flang-commits] [flang] 4290e34 - [flang][AMDGPU] Convert math ops to AMD GPU library calls instead of libm calls (#99517)

via flang-commits flang-commits at lists.llvm.org
Tue Sep 10 06:48:58 PDT 2024


Author: Jan Leyonberg
Date: 2024-09-10T09:48:55-04:00
New Revision: 4290e34ebdddaa62210745c84ac3e6703cadfa34

URL: https://github.com/llvm/llvm-project/commit/4290e34ebdddaa62210745c84ac3e6703cadfa34
DIFF: https://github.com/llvm/llvm-project/commit/4290e34ebdddaa62210745c84ac3e6703cadfa34.diff

LOG: [flang][AMDGPU] Convert math ops to AMD GPU library calls instead of libm calls (#99517)

This patch invokes a pass when compiling for an AMDGPU target to lower
math operations to AMD GPU library calls instead of libm calls.

Added: 
    flang/test/Lower/OpenMP/math-amdgpu.f90

Modified: 
    flang/lib/Optimizer/CodeGen/CMakeLists.txt
    flang/lib/Optimizer/CodeGen/CodeGen.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 650448eee10993..646621cb01c157 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
+  MLIRMathToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation

diff  --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index ac521ae95df39c..88293bcf36a780 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -36,6 +36,7 @@
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -3671,6 +3672,14 @@ class FIRToLLVMLowering
     // as passes here.
     mlir::OpPassManager mathConvertionPM("builtin.module");
 
+    bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
+    // If compiling for AMD target some math operations must be lowered to AMD
+    // GPU library calls, the rest can be converted to LLVM intrinsics, which
+    // is handled in the mathToLLVM conversion. The lowering to libm calls is
+    // not needed since all math operations are handled this way.
+    if (isAMDGCN)
+      mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
     // it will be lowered to LLVM intrinsic operation by a later conversion.
@@ -3710,7 +3719,8 @@ class FIRToLLVMLowering
                                                           pattern);
     // Math operations that have not been converted yet must be converted
     // to Libm.
-    mlir::populateMathToLibmConversionPatterns(pattern);
+    if (!isAMDGCN)
+      mlir::populateMathToLibmConversionPatterns(pattern);
     mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
     mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);
 

diff  --git a/flang/test/Lower/OpenMP/math-amdgpu.f90 b/flang/test/Lower/OpenMP/math-amdgpu.f90
new file mode 100644
index 00000000000000..116768ba9412a5
--- /dev/null
+++ b/flang/test/Lower/OpenMP/math-amdgpu.f90
@@ -0,0 +1,184 @@
+!REQUIRES: amdgpu-registered-target
+!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+
+subroutine omp_pow_f32(x, y)  ! expect x**x lowered to the AMD GPU library call __ocml_pow_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_pow_f32(float {{.*}}, float {{.*}})
+  y = x ** x  ! NOTE(review): same operand twice, so argument order is not verified
+end subroutine omp_pow_f32
+
+subroutine omp_pow_f64(x, y)  ! expect x**x lowered to the AMD GPU library call __ocml_pow_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_pow_f64(double {{.*}}, double {{.*}})
+  y = x ** x  ! NOTE(review): same operand twice, so argument order is not verified
+end subroutine omp_pow_f64
+
+subroutine omp_sin_f32(x, y)  ! expect sin lowered to the AMD GPU library call __ocml_sin_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_sin_f32(float {{.*}})
+  y = sin(x)
+end subroutine omp_sin_f32
+
+subroutine omp_sin_f64(x, y)  ! expect sin lowered to the AMD GPU library call __ocml_sin_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_sin_f64(double {{.*}})
+  y = sin(x)
+end subroutine omp_sin_f64
+
+subroutine omp_abs_f32(x, y)  ! expect abs to stay the llvm.fabs.f32 intrinsic, not a library call
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call contract float @llvm.fabs.f32(float {{.*}})
+  y = abs(x)
+end subroutine omp_abs_f32
+
+subroutine omp_abs_f64(x, y)  ! expect abs to stay the llvm.fabs.f64 intrinsic, not a library call
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call contract double @llvm.fabs.f64(double {{.*}})
+  y = abs(x)
+end subroutine omp_abs_f64
+
+subroutine omp_atan_f32(x, y)  ! expect atan lowered to the AMD GPU library call __ocml_atan_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_atan_f32(float {{.*}})
+  y = atan(x)
+end subroutine omp_atan_f32
+
+subroutine omp_atan_f64(x, y)  ! expect atan lowered to the AMD GPU library call __ocml_atan_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_atan_f64(double {{.*}})
+  y = atan(x)
+end subroutine omp_atan_f64
+
+subroutine omp_atan2_f32(x, y)  ! expect atan2 lowered to the AMD GPU library call __ocml_atan2_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_atan2_f32(float {{.*}}, float {{.*}})
+  y = atan2(x, x)  ! NOTE(review): same operand twice, so argument order is not verified
+end subroutine omp_atan2_f32
+
+subroutine omp_atan2_f64(x, y)  ! expect atan2 lowered to the AMD GPU library call __ocml_atan2_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_atan2_f64(double {{.*}}, double {{.*}})
+  y = atan2(x, x)  ! NOTE(review): same operand twice, so argument order is not verified
+end subroutine omp_atan2_f64
+
+subroutine omp_cos_f32(x, y)  ! expect cos lowered to the AMD GPU library call __ocml_cos_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_cos_f32(float {{.*}})
+  y = cos(x)
+end subroutine omp_cos_f32
+
+subroutine omp_cos_f64(x, y)  ! expect cos lowered to the AMD GPU library call __ocml_cos_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_cos_f64(double {{.*}})
+  y = cos(x)
+end subroutine omp_cos_f64
+
+subroutine omp_erf_f32(x, y)  ! expect erf lowered to the AMD GPU library call __ocml_erf_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_erf_f32(float {{.*}})
+  y = erf(x)
+end subroutine omp_erf_f32
+
+subroutine omp_erf_f64(x, y)  ! expect erf lowered to the AMD GPU library call __ocml_erf_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_erf_f64(double {{.*}})
+  y = erf(x)
+end subroutine omp_erf_f64
+
+subroutine omp_exp_f32(x, y)  ! expect exp to stay the llvm.exp.f32 intrinsic (the f64 variant uses __ocml)
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call contract float @llvm.exp.f32(float {{.*}})
+  y = exp(x)
+end subroutine omp_exp_f32
+
+subroutine omp_exp_f64(x, y)  ! expect exp lowered to the AMD GPU library call __ocml_exp_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_exp_f64(double {{.*}})
+  y = exp(x)
+end subroutine omp_exp_f64
+
+subroutine omp_log_f32(x, y)  ! expect log to stay the llvm.log.f32 intrinsic (the f64 variant uses __ocml)
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call contract float @llvm.log.f32(float {{.*}})
+  y = log(x)
+end subroutine omp_log_f32
+
+subroutine omp_log_f64(x, y)  ! expect log lowered to the AMD GPU library call __ocml_log_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_log_f64(double {{.*}})
+  y = log(x)
+end subroutine omp_log_f64
+
+subroutine omp_log10_f32(x, y)  ! expect log10 lowered to the AMD GPU library call __ocml_log10_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_log10_f32(float {{.*}})
+  y = log10(x)
+end subroutine omp_log10_f32
+
+subroutine omp_log10_f64(x, y)  ! expect log10 lowered to the AMD GPU library call __ocml_log10_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_log10_f64(double {{.*}})
+  y = log10(x)
+end subroutine omp_log10_f64
+
+subroutine omp_sqrt_f32(x, y)  ! expect sqrt to stay the llvm.sqrt.f32 intrinsic, not a library call
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call contract float @llvm.sqrt.f32(float {{.*}})
+  y = sqrt(x)
+end subroutine omp_sqrt_f32
+
+subroutine omp_sqrt_f64(x, y)  ! expect sqrt to stay the llvm.sqrt.f64 intrinsic, not a library call
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call contract double @llvm.sqrt.f64(double {{.*}})
+  y = sqrt(x)
+end subroutine omp_sqrt_f64
+
+subroutine omp_tan_f32(x, y)  ! expect tan lowered to the AMD GPU library call __ocml_tan_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_tan_f32(float {{.*}})
+  y = tan(x)
+end subroutine omp_tan_f32
+
+subroutine omp_tan_f64(x, y)  ! expect tan lowered to the AMD GPU library call __ocml_tan_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_tan_f64(double {{.*}})
+  y = tan(x)
+end subroutine omp_tan_f64
+
+subroutine omp_tanh_f32(x, y)  ! expect tanh lowered to the AMD GPU library call __ocml_tanh_f32
+!$omp declare target
+  real :: x, y  ! x: input, y: result
+!CHECK: call float @__ocml_tanh_f32(float {{.*}})
+  y = tanh(x)
+end subroutine omp_tanh_f32
+
+subroutine omp_tanh_f64(x, y)  ! expect tanh lowered to the AMD GPU library call __ocml_tanh_f64
+!$omp declare target
+  real(8) :: x, y  ! x: input, y: result
+!CHECK: call double @__ocml_tanh_f64(double {{.*}})
+  y = tanh(x)
+end subroutine omp_tanh_f64


        


More information about the flang-commits mailing list