[Mlir-commits] [mlir] [MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass (PR #98653)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 12 08:20:51 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jan Leyonberg (jsjodin)
<details>
<summary>Changes</summary>
This patch refactors the conversion of math operations to ROCDL library calls. This pass will also be used in flang to lower Fortran intrinsics/math functions for OpenMP target offloading codegen.
---
Patch is 31.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98653.diff
9 Files Affected:
- (added) mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h (+26)
- (modified) mlir/include/mlir/Conversion/Passes.h (+1)
- (modified) mlir/include/mlir/Conversion/Passes.td (+18)
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
- (modified) mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt (+1)
- (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+2-44)
- (added) mlir/lib/Conversion/MathToROCDL/CMakeLists.txt (+23)
- (added) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+146)
- (added) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+435)
``````````diff
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
new file mode 100644
index 0000000000000..fa7a635568c7c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -0,0 +1,26 @@
+//===- MathToROCDL.h - Utils to convert from the math dialect -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to ROCDL calls.
+void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8c6f85d461aea..208f26489d6c3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -46,6 +46,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..64835b1b660b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -733,6 +733,24 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
+ let summary = "Convert Math dialect to ROCDL library calls";
+ let description = [{
+ This pass converts supported Math ops to ROCDL library calls.
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "func::FuncDialect",
+ "math::MathDialect",
+ "ROCDL::ROCDLDialect",
+ "vector::VectorDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// MathToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e107738a4c50c..80c8b84d9ae89 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
+add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 70707b5c3a049..945e3ccdfa87b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
MLIRArithToLLVM
MLIRArithTransforms
MLIRMathToLLVM
+ MLIRMathToROCDL
MLIRAMDGPUToROCDL
MLIRFuncToLLVM
MLIRGPUDialect
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 40eb15a491063..100181cdc69fe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -26,6 +26,7 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
- populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
- "__ocml_fabs_f64");
- populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
- "__ocml_atan_f64");
- populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
- "__ocml_atan2_f64");
- populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
- "__ocml_cbrt_f64");
- populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
- "__ocml_ceil_f64");
- populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
- "__ocml_cos_f64");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
- "__ocml_exp_f64");
- populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
- "__ocml_exp2_f64");
- populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
- "__ocml_expm1_f64");
- populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
- "__ocml_floor_f64");
- populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
- "__ocml_fmod_f64");
- populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
- "__ocml_log_f64");
- populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
- "__ocml_log10_f64");
- populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
- "__ocml_log1p_f64");
- populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
- "__ocml_log2_f64");
- populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
- "__ocml_pow_f64");
- populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
- "__ocml_rsqrt_f64");
- populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
- "__ocml_sin_f64");
- populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
- "__ocml_sqrt_f64");
- populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
- "__ocml_tanh_f64");
- populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
- "__ocml_tan_f64");
- populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
- "__ocml_erf_f64");
+ populateMathToROCDLConversionPatterns(converter, patterns);
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..2771955aa9493
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToROCDL
+ MathToROCDL.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
new file mode 100644
index 0000000000000..03c7ce5dac0d1
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -0,0 +1,146 @@
+//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-rocdl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns, StringRef f32Func,
+ StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
+void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ // Handled by mathToLLVM: math::AbsIOp
+ // Handled by mathToLLVM: math::CopySignOp
+ // Handled by mathToLLVM: math::CountLeadingZerosOp
+ // Handled by mathToLLVM: math::CountTrailingZerosOp
+ // Handled by mathToLLVM: math::CtPopOp
+ // Handled by mathToLLVM: math::FmaOp
+ // FIXME: math::IPowIOp
+ // FIXME: math::FPowIOp
+ // Handled by mathToLLVM: math::RoundEvenOp
+ // Handled by mathToLLVM: math::RoundOp
+ // Handled by mathToLLVM: math::TruncOp
+ populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
+ "__ocml_fabs_f64");
+ populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
+ "__ocml_acos_f64");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
+ "__ocml_acosh_f64");
+ populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
+ "__ocml_asin_f64");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
+ "__ocml_asinh_f64");
+ populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
+ "__ocml_atan_f64");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
+ "__ocml_atanh_f64");
+ populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
+ "__ocml_atan2_f64");
+ populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
+ "__ocml_cbrt_f64");
+ populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
+ "__ocml_ceil_f64");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
+ "__ocml_cos_f64");
+ populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
+ "__ocml_cosh_f64");
+ populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
+ "__ocml_sinh_f64");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
+ "__ocml_exp_f64");
+ populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
+ "__ocml_exp2_f64");
+ populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
+ "__ocml_expm1_f64");
+ populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
+ "__ocml_floor_f64");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
+ "__ocml_log_f64");
+ populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
+ "__ocml_log10_f64");
+ populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
+ "__ocml_log1p_f64");
+ populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
+ "__ocml_log2_f64");
+ populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
+ "__ocml_pow_f64");
+ populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
+ "__ocml_rsqrt_f64");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
+ "__ocml_sin_f64");
+ populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
+ "__ocml_sqrt_f64");
+ populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
+ "__ocml_tanh_f64");
+ populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
+ "__ocml_tan_f64");
+ populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
+ "__ocml_erf_f64");
+ // Single arith pattern that needs a ROCDL call, probably not
+ // worth creating a separate pass for it.
+ populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
+ "__ocml_fmod_f64");
+}
+
+namespace {
+struct ConvertMathToROCDLPass
+ : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+ ConvertMathToROCDLPass() = default;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToROCDLPass::runOnOperation() {
+ auto m = getOperation();
+ MLIRContext *ctx = m.getContext();
+
+ RewritePatternSet patterns(&getContext());
+ LowerToLLVMOptions options(ctx, DataLayout(m));
+ LLVMTypeConverter converter(ctx, options);
+ populateMathToROCDLConversionPatterns(converter, patterns);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+ vector::VectorDialect, LLVM::LLVMDialect>();
+ target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+ LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
+ LLVM::SqrtOp>();
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
new file mode 100644
index 0000000000000..a406ec45a7f10
--- /dev/null
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -0,0 +1,435 @@
+// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
+ // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
+ // CHECK-LABEL: func @arith_remf
+ func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = arith.remf %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = arith.remf %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
+ // CHECK-LABEL: func @math_absf
+ func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.absf %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.absf %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
+ // CHECK-LABEL: func @math_acos
+ func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_acosh
+ func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
+ // CHECK-LABEL: func @math_asin
+ func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_asinh
+ func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
+ // CHECK-LABEL: func @math_atan
+ func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atan %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.atan %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_atanh
+ func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/98653
More information about the Mlir-commits
mailing list