[flang-commits] [flang] [flang][AMDGPU] Convert math ops to AMD GPU library calls instead of libm calls (PR #99517)
Jan Leyonberg via flang-commits
flang-commits at lists.llvm.org
Thu Jul 18 08:42:07 PDT 2024
https://github.com/jsjodin created https://github.com/llvm/llvm-project/pull/99517
This patch invokes a pass when compiling for an AMDGPU target to lower math operations to AMD GPU library calls instead of libm calls.
>From 665a76feb553c855e82be44629b86d521b4cdc2a Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Thu, 18 Jul 2024 11:06:51 -0400
Subject: [PATCH] [flang][AMDGPU] Convert math ops to AMD GPU library calls
instead of libm calls
This patch invokes a pass when compiling for an AMDGPU target to lower math
operations to AMD GPU library calls instead of libm calls.
---
flang/lib/Optimizer/CodeGen/CMakeLists.txt | 1 +
flang/lib/Optimizer/CodeGen/CodeGen.cpp | 12 +-
flang/test/Lower/OpenMP/math-amdgpu.f90 | 184 +++++++++++++++++++++
3 files changed, 196 insertions(+), 1 deletion(-)
create mode 100644 flang/test/Lower/OpenMP/math-amdgpu.f90
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 650448eee1099..646621cb01c15 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
MLIRMathToFuncs
MLIRMathToLLVM
MLIRMathToLibm
+ MLIRMathToROCDL
MLIROpenMPToLLVM
MLIROpenACCDialect
MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f9ea92a843b23..02992857dde06 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -34,6 +34,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -3610,6 +3611,14 @@ class FIRToLLVMLowering
// as passes here.
mlir::OpPassManager mathConvertionPM("builtin.module");
+ bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
+ // When compiling for an AMD GPU target, some math operations must be
+ // lowered to AMD GPU library calls; the rest can be converted to LLVM
+ // intrinsics, which is handled in the mathToLLVM conversion. Lowering to
+ // libm calls is not needed since all math operations are handled this way.
+ if (isAMDGCN)
+ mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+
// Convert math::FPowI operations to inline implementation
// only if the exponent's width is greater than 32, otherwise,
// it will be lowered to LLVM intrinsic operation by a later conversion.
@@ -3649,7 +3658,8 @@ class FIRToLLVMLowering
pattern);
// Math operations that have not been converted yet must be converted
// to Libm.
- mlir::populateMathToLibmConversionPatterns(pattern);
+ if (!isAMDGCN)
+ mlir::populateMathToLibmConversionPatterns(pattern);
mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);
diff --git a/flang/test/Lower/OpenMP/math-amdgpu.f90 b/flang/test/Lower/OpenMP/math-amdgpu.f90
new file mode 100644
index 0000000000000..b455b42d3ed34
--- /dev/null
+++ b/flang/test/Lower/OpenMP/math-amdgpu.f90
@@ -0,0 +1,184 @@
+!REQUIRES: amdgpu-registered-target
+!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+
+subroutine omp_pow_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_pow_f32(float {{.*}}, float {{.*}})
+ y = x ** x
+end subroutine omp_pow_f32
+
+subroutine omp_pow_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_pow_f64(double {{.*}}, double {{.*}})
+ y = x ** x
+end subroutine omp_pow_f64
+
+subroutine omp_sin_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_sin_f32(float {{.*}})
+ y = sin(x)
+end subroutine omp_sin_f32
+
+subroutine omp_sin_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_sin_f64(double {{.*}})
+ y = sin(x)
+end subroutine omp_sin_f64
+
+subroutine omp_abs_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_fabs_f32(float {{.*}})
+ y = abs(x)
+end subroutine omp_abs_f32
+
+subroutine omp_abs_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_fabs_f64(double {{.*}})
+ y = abs(x)
+end subroutine omp_abs_f64
+
+subroutine omp_atan_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_atan_f32(float {{.*}})
+ y = atan(x)
+end subroutine omp_atan_f32
+
+subroutine omp_atan_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_atan_f64(double {{.*}})
+ y = atan(x)
+end subroutine omp_atan_f64
+
+subroutine omp_atan2_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_atan2_f32(float {{.*}}, float {{.*}})
+ y = atan2(x, x)
+end subroutine omp_atan2_f32
+
+subroutine omp_atan2_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_atan2_f64(double {{.*}}, double {{.*}})
+ y = atan2(x ,x)
+end subroutine omp_atan2_f64
+
+subroutine omp_cos_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_cos_f32(float {{.*}})
+ y = cos(x)
+end subroutine omp_cos_f32
+
+subroutine omp_cos_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_cos_f64(double {{.*}})
+ y = cos(x)
+end subroutine omp_cos_f64
+
+subroutine omp_erf_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_erf_f32(float {{.*}})
+ y = erf(x)
+end subroutine omp_erf_f32
+
+subroutine omp_erf_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_erf_f64(double {{.*}})
+ y = erf(x)
+end subroutine omp_erf_f64
+
+subroutine omp_exp_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_exp_f32(float {{.*}})
+ y = exp(x)
+end subroutine omp_exp_f32
+
+subroutine omp_exp_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_exp_f64(double {{.*}})
+ y = exp(x)
+end subroutine omp_exp_f64
+
+subroutine omp_log_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_log_f32(float {{.*}})
+ y = log(x)
+end subroutine omp_log_f32
+
+subroutine omp_log_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_log_f64(double {{.*}})
+ y = log(x)
+end subroutine omp_log_f64
+
+subroutine omp_log10_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_log10_f32(float {{.*}})
+ y = log10(x)
+end subroutine omp_log10_f32
+
+subroutine omp_log10_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_log10_f64(double {{.*}})
+ y = log10(x)
+end subroutine omp_log10_f64
+
+subroutine omp_sqrt_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_sqrt_f32(float {{.*}})
+ y = sqrt(x)
+end subroutine omp_sqrt_f32
+
+subroutine omp_sqrt_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_sqrt_f64(double {{.*}})
+ y = sqrt(x)
+end subroutine omp_sqrt_f64
+
+subroutine omp_tan_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_tan_f32(float {{.*}})
+ y = tan(x)
+end subroutine omp_tan_f32
+
+subroutine omp_tan_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_tan_f64(double {{.*}})
+ y = tan(x)
+end subroutine omp_tan_f64
+
+subroutine omp_tanh_f32(x, y)
+!$omp declare target
+ real :: x, y
+!CHECK: call float @__ocml_tanh_f32(float {{.*}})
+ y = tanh(x)
+end subroutine omp_tanh_f32
+
+subroutine omp_tanh_f64(x, y)
+!$omp declare target
+ real(8) :: x, y
+!CHECK: call double @__ocml_tanh_f64(double {{.*}})
+ y = tanh(x)
+end subroutine omp_tanh_f64
More information about the flang-commits
mailing list