[Mlir-commits] [mlir] [MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass (PR #98653)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 12 08:20:51 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jan Leyonberg (jsjodin)

<details>
<summary>Changes</summary>

This patch refactors the conversion of math operations to ROCDL library calls into a separate pass. This pass will also be used in Flang to lower Fortran intrinsics/math functions for OpenMP target offloading codegen.

---

Patch is 31.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98653.diff


9 Files Affected:

- (added) mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h (+26) 
- (modified) mlir/include/mlir/Conversion/Passes.h (+1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+18) 
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt (+1) 
- (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+2-44) 
- (added) mlir/lib/Conversion/MathToROCDL/CMakeLists.txt (+23) 
- (added) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+146) 
- (added) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+435) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
new file mode 100644
index 0000000000000..fa7a635568c7c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -0,0 +1,26 @@
+//===- MathToROCDL.h - Utils to convert from the math dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+// Pull in the tablegen-generated declarations for the convert-math-to-rocdl
+// pass (ConvertMathToROCDL in Passes.td).
+#define GEN_PASS_DECL_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate `patterns` with lowerings that replace supported math dialect ops
+/// (plus arith.remf) with calls to `__ocml_<name>_f32` / `__ocml_<name>_f64`
+/// for f32 and f64 operands respectively. A ScalarizeVectorOpLowering pattern
+/// is registered alongside each call lowering. `converter` supplies the LLVM
+/// type conversion the patterns are constructed with.
+void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+                                           RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8c6f85d461aea..208f26489d6c3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -46,6 +46,7 @@
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..64835b1b660b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -733,6 +733,24 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
+  let summary = "Convert Math dialect to ROCDL library calls";
+  let description = [{
+    This pass converts supported Math ops (and `arith.remf`) on f32/f64 values
+    to calls to the `__ocml_*_f32` / `__ocml_*_f64` library functions. The
+    callees are declared as `llvm.func` ops and invoked with `llvm.call` ops.
+  }];
+  // The rewrite materializes LLVM dialect ops (`llvm.func`, `llvm.call`), so
+  // the LLVM dialect is declared explicitly instead of relying on it being
+  // loaded transitively by another dialect.
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "func::FuncDialect",
+    "LLVM::LLVMDialect",
+    "math::MathDialect",
+    "ROCDL::ROCDLDialect",
+    "vector::VectorDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e107738a4c50c..80c8b84d9ae89 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
+add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 70707b5c3a049..945e3ccdfa87b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
   MLIRArithToLLVM
   MLIRArithTransforms
   MLIRMathToLLVM
+  MLIRMathToROCDL
   MLIRAMDGPUToROCDL
   MLIRFuncToLLVM
   MLIRGPUDialect
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 40eb15a491063..100181cdc69fe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
 
   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
 
-  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
-                                   "__ocml_fabs_f64");
-  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64");
-  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64");
-  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64");
-  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64");
-  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
-                                  "__ocml_exp_f64");
-  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64");
-  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64");
-  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64");
-  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
-                                    "__ocml_fmod_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
-                                  "__ocml_log_f64");
-  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
-                                    "__ocml_log10_f64");
-  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
-                                    "__ocml_log1p_f64");
-  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
-                                   "__ocml_log2_f64");
-  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
-                                   "__ocml_pow_f64");
-  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
-                                    "__ocml_rsqrt_f64");
-  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
-                                  "__ocml_sin_f64");
-  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
-                                   "__ocml_sqrt_f64");
-  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
-                                   "__ocml_tanh_f64");
-  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
-                                  "__ocml_tan_f64");
-  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
-                                  "__ocml_erf_f64");
+  populateMathToROCDLConversionPatterns(converter, patterns);
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..2771955aa9493
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToROCDL
+  MathToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # Kept in alphabetical order. MLIRGPUToGPURuntimeTransforms is presumably
+  # what provides the ../GPUCommon lowering helpers MathToROCDL.cpp includes
+  # (NOTE(review): confirm).
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRLLVMCommonConversion
+  MLIRMathDialect
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
new file mode 100644
index 0000000000000..03c7ce5dac0d1
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -0,0 +1,146 @@
+//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+
+namespace mlir {
+// Emit the tablegen-generated impl::ConvertMathToROCDLBase that the pass
+// below derives from.
+#define GEN_PASS_DEF_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-rocdl"
+
+/// Register the two lowerings for a single op type `OpTy`:
+/// ScalarizeVectorOpLowering<OpTy>, and OpToFuncCallLowering<OpTy>
+/// parameterized with the f32 callee `f32Func` and the f64 callee `f64Func`.
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns, StringRef f32Func,
+                               StringRef f64Func) {
+  // RewritePatternSet::add returns the set itself, so the two registrations
+  // are chained in the same order as before.
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter)
+      .add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
+/// Register the math (and arith.remf) -> `__ocml_*` call lowerings. Each op is
+/// registered through populateOpPatterns with its f32 and f64 callee names.
+/// Ops listed in the leading comments are intentionally absent: they are
+/// either handled by MathToLLVM or not yet supported (FIXME).
+void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+                                                 RewritePatternSet &patterns) {
+  // Handled by mathToLLVM: math::AbsIOp
+  // Handled by mathToLLVM: math::CopySignOp
+  // Handled by mathToLLVM: math::CountLeadingZerosOp
+  // Handled by mathToLLVM: math::CountTrailingZerosOp
+  // Handled by mathToLLVM: math::CtPopOp
+  // Handled by mathToLLVM: math::FmaOp
+  // FIXME: math::IPowIOp
+  // FIXME: math::FPowIOp
+  // Handled by mathToLLVM: math::RoundEvenOp
+  // Handled by mathToLLVM: math::RoundOp
+  // Handled by mathToLLVM: math::TruncOp
+  //
+  // Math registrations are kept in alphabetical order of the op name; each
+  // pattern has a distinct root op, so the order has no effect on matching.
+  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
+                                   "__ocml_fabs_f64");
+  populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
+                                   "__ocml_acos_f64");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
+                                    "__ocml_acosh_f64");
+  populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
+                                   "__ocml_asin_f64");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
+                                    "__ocml_asinh_f64");
+  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
+                                   "__ocml_atan_f64");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
+                                    "__ocml_atan2_f64");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
+                                    "__ocml_atanh_f64");
+  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
+                                   "__ocml_cbrt_f64");
+  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
+                                   "__ocml_ceil_f64");
+  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
+                                  "__ocml_cos_f64");
+  populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
+                                   "__ocml_cosh_f64");
+  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
+                                  "__ocml_erf_f64");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
+                                  "__ocml_exp_f64");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
+                                   "__ocml_exp2_f64");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
+                                    "__ocml_expm1_f64");
+  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
+                                    "__ocml_floor_f64");
+  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
+                                  "__ocml_log_f64");
+  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
+                                    "__ocml_log10_f64");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
+                                    "__ocml_log1p_f64");
+  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
+                                   "__ocml_log2_f64");
+  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
+                                   "__ocml_pow_f64");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
+                                    "__ocml_rsqrt_f64");
+  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
+                                  "__ocml_sin_f64");
+  populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
+                                   "__ocml_sinh_f64");
+  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
+                                   "__ocml_sqrt_f64");
+  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
+                                  "__ocml_tan_f64");
+  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
+                                   "__ocml_tanh_f64");
+  // arith.remf is the only arith op lowered to a ROCDL library call, which is
+  // not enough to justify a separate pass, so it is handled here.
+  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
+                                    "__ocml_fmod_f64");
+}
+
+namespace {
+/// Pass driver for populateMathToROCDLConversionPatterns; the conversion
+/// target and pattern set are assembled in runOnOperation().
+struct ConvertMathToROCDLPass
+    : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+  ConvertMathToROCDLPass() = default;
+
+  /// Applies the math -> ROCDL library call conversion.
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Lower every supported math op (and arith.remf) in the module to an
+/// `__ocml_*` call. This is a partial conversion, so ops with no registered
+/// pattern and no explicit illegality are left untouched.
+void ConvertMathToROCDLPass::runOnOperation() {
+  ModuleOp m = getOperation();
+  // Use a single context handle for every object built below.
+  MLIRContext *ctx = m.getContext();
+
+  // The LLVM type converter honors the module's data layout.
+  LowerToLLVMOptions options(ctx, DataLayout(m));
+  LLVMTypeConverter converter(ctx, options);
+  RewritePatternSet patterns(ctx);
+  populateMathToROCDLConversionPatterns(converter, patterns);
+
+  ConversionTarget target(*ctx);
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         vector::VectorDialect, LLVM::LLVMDialect>();
+  // NOTE(review): these LLVM intrinsic ops are marked illegal although the
+  // patterns registered above presumably never create them. Since no pattern
+  // here rewrites them, an input that already contains one will make the
+  // partial conversion fail -- confirm this is intended.
+  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
+                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
+                      LLVM::SqrtOp>();
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
new file mode 100644
index 0000000000000..a406ec45a7f10
--- /dev/null
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -0,0 +1,435 @@
+// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s
+// Each split-input section lowers one op and verifies that its f32 and f64
+// forms become llvm.call ops to the matching __ocml_*_f32 / __ocml_*_f64
+// llvm.func declarations.
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
+  // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
+  // CHECK-LABEL: func @arith_remf
+  func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = arith.remf %arg_f32, %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result64 = arith.remf %arg_f64, %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
+  // CHECK-LABEL: func @math_absf
+  func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.absf %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.absf %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
+  // CHECK-LABEL: func @math_acos
+  func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acos %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.acos %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_acosh
+  func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acosh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.acosh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
+  // CHECK-LABEL: func @math_asin
+  func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asin %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.asin %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_asinh
+  func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asinh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.asinh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
+  // CHECK-LABEL: func @math_atan
+  func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atan %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.atan %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_atanh
+  func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/98653


More information about the Mlir-commits mailing list