[Mlir-commits] [mlir] [MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass (PR #98653)
Jan Leyonberg
llvmlistbot at llvm.org
Fri Jul 12 08:20:21 PDT 2024
https://github.com/jsjodin created https://github.com/llvm/llvm-project/pull/98653
This patch refactors the conversion of math operations to ROCDL library calls. This pass will also be used in flang to lower Fortran intrinsics/math functions for OpenMP target offloading codegen.
>From 5372b0fd207b83345e29075af4de7ec4a43c8329 Mon Sep 17 00:00:00 2001
From: Jan Leyonberg <jan_sjodin at yahoo.com>
Date: Tue, 18 Jun 2024 11:05:21 -0400
Subject: [PATCH] [MLIR][ROCDL] Refactor conversion of math operations to ROCDL
calls to a separate pass
This patch refactors the conversion of math operations to ROCDL library
calls. This pass will also be used in flang to lower Fortran intrinsics/math
functions for OpenMP target offloading codegen.
---
.../mlir/Conversion/MathToROCDL/MathToROCDL.h | 26 ++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 18 +
mlir/lib/Conversion/CMakeLists.txt | 1 +
mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt | 1 +
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 46 +-
.../lib/Conversion/MathToROCDL/CMakeLists.txt | 23 +
.../Conversion/MathToROCDL/MathToROCDL.cpp | 146 ++++++
.../Conversion/MathToROCDL/math-to-rocdl.mlir | 435 ++++++++++++++++++
9 files changed, 653 insertions(+), 44 deletions(-)
create mode 100644 mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
create mode 100644 mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
create mode 100644 mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
create mode 100644 mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
new file mode 100644
index 0000000000000..fa7a635568c7c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -0,0 +1,26 @@
+//===- MathToROCDL.h - Utils to convert from the math dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to ROCDL calls.
+void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8c6f85d461aea..208f26489d6c3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -46,6 +46,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..64835b1b660b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -733,6 +733,24 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
+ let summary = "Convert Math dialect to ROCDL library calls";
+ let description = [{
+ This pass converts supported Math ops to ROCDL library calls.
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "func::FuncDialect",
+ "math::MathDialect",
+ "ROCDL::ROCDLDialect",
+ "vector::VectorDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// MathToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e107738a4c50c..80c8b84d9ae89 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
+add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 70707b5c3a049..945e3ccdfa87b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
MLIRArithToLLVM
MLIRArithTransforms
MLIRMathToLLVM
+ MLIRMathToROCDL
MLIRAMDGPUToROCDL
MLIRFuncToLLVM
MLIRGPUDialect
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 40eb15a491063..100181cdc69fe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -26,6 +26,7 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
- populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
- "__ocml_fabs_f64");
- populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
- "__ocml_atan_f64");
- populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
- "__ocml_atan2_f64");
- populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
- "__ocml_cbrt_f64");
- populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
- "__ocml_ceil_f64");
- populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
- "__ocml_cos_f64");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
- "__ocml_exp_f64");
- populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
- "__ocml_exp2_f64");
- populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
- "__ocml_expm1_f64");
- populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
- "__ocml_floor_f64");
- populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
- "__ocml_fmod_f64");
- populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
- "__ocml_log_f64");
- populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
- "__ocml_log10_f64");
- populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
- "__ocml_log1p_f64");
- populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
- "__ocml_log2_f64");
- populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
- "__ocml_pow_f64");
- populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
- "__ocml_rsqrt_f64");
- populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
- "__ocml_sin_f64");
- populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
- "__ocml_sqrt_f64");
- populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
- "__ocml_tanh_f64");
- populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
- "__ocml_tan_f64");
- populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
- "__ocml_erf_f64");
+ populateMathToROCDLConversionPatterns(converter, patterns);
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..2771955aa9493
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToROCDL
+ MathToROCDL.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
new file mode 100644
index 0000000000000..03c7ce5dac0d1
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -0,0 +1,146 @@
+//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-rocdl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns, StringRef f32Func,
+ StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
+void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ // Handled by mathToLLVM: math::AbsIOp
+ // Handled by mathToLLVM: math::CopySignOp
+ // Handled by mathToLLVM: math::CountLeadingZerosOp
+ // Handled by mathToLLVM: math::CountTrailingZerosOp
+ // Handled by mathToLLVM: math::CtPopOp
+ // Handled by mathToLLVM: math::FmaOp
+ // FIXME: math::IPowIOp
+ // FIXME: math::FPowIOp
+ // Handled by mathToLLVM: math::RoundEvenOp
+ // Handled by mathToLLVM: math::RoundOp
+ // Handled by mathToLLVM: math::TruncOp
+ populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
+ "__ocml_fabs_f64");
+ populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
+ "__ocml_acos_f64");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
+ "__ocml_acosh_f64");
+ populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
+ "__ocml_asin_f64");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
+ "__ocml_asinh_f64");
+ populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
+ "__ocml_atan_f64");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
+ "__ocml_atanh_f64");
+ populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
+ "__ocml_atan2_f64");
+ populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
+ "__ocml_cbrt_f64");
+ populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
+ "__ocml_ceil_f64");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
+ "__ocml_cos_f64");
+ populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
+ "__ocml_cosh_f64");
+ populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
+ "__ocml_sinh_f64");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
+ "__ocml_exp_f64");
+ populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
+ "__ocml_exp2_f64");
+ populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
+ "__ocml_expm1_f64");
+ populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
+ "__ocml_floor_f64");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
+ "__ocml_log_f64");
+ populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
+ "__ocml_log10_f64");
+ populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
+ "__ocml_log1p_f64");
+ populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
+ "__ocml_log2_f64");
+ populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
+ "__ocml_pow_f64");
+ populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
+ "__ocml_rsqrt_f64");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
+ "__ocml_sin_f64");
+ populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
+ "__ocml_sqrt_f64");
+ populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
+ "__ocml_tanh_f64");
+ populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
+ "__ocml_tan_f64");
+ populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
+ "__ocml_erf_f64");
+ // Single arith pattern that needs a ROCDL call, probably not
+ // worth creating a separate pass for it.
+ populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
+ "__ocml_fmod_f64");
+}
+
+namespace {
+struct ConvertMathToROCDLPass
+ : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+ ConvertMathToROCDLPass() = default;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToROCDLPass::runOnOperation() {
+ auto m = getOperation();
+ MLIRContext *ctx = m.getContext();
+
+ RewritePatternSet patterns(&getContext());
+ LowerToLLVMOptions options(ctx, DataLayout(m));
+ LLVMTypeConverter converter(ctx, options);
+ populateMathToROCDLConversionPatterns(converter, patterns);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+ vector::VectorDialect, LLVM::LLVMDialect>();
+ target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+ LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
+ LLVM::SqrtOp>();
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
new file mode 100644
index 0000000000000..a406ec45a7f10
--- /dev/null
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -0,0 +1,435 @@
+// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
+ // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
+ // CHECK-LABEL: func @arith_remf
+ func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = arith.remf %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = arith.remf %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
+ // CHECK-LABEL: func @math_absf
+ func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.absf %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.absf %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
+ // CHECK-LABEL: func @math_acos
+ func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_acosh
+ func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
+ // CHECK-LABEL: func @math_asin
+ func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_asinh
+ func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
+ // CHECK-LABEL: func @math_atan
+ func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atan %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.atan %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_atanh
+ func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atanh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.atanh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32
+ // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64
+ // CHECK-LABEL: func @math_atan2
+ func.func @math_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atan2 %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = math.atan2 %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
+ // CHECK-LABEL: func @math_cbrt
+ func.func @math_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.cbrt %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.cbrt %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
+ // CHECK-LABEL: func @math_ceil
+ func.func @math_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.ceil %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.ceil %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
+ // CHECK-LABEL: func @math_cos
+ func.func @math_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.cos %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.cos %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_cosh
+ func.func @math_cosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.cosh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.cosh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_sinh
+ func.func @math_sinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.sinh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.sinh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
+ // CHECK-LABEL: func @math_exp
+ func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.exp %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.exp %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
+ // CHECK-LABEL: func @math_exp2
+ func.func @math_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.exp2 %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.exp2 %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
+ // CHECK-LABEL: func @math_expm1
+ func.func @math_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.expm1 %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.expm1 %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
+ // CHECK-LABEL: func @math_floor
+ func.func @math_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.floor %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.floor %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
+ // CHECK-LABEL: func @math_log
+ func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.log %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.log %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
+ // CHECK-LABEL: func @math_log10
+ func.func @math_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.log10 %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.log10 %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
+ // CHECK-LABEL: func @math_log1p
+ func.func @math_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.log1p %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.log1p %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32
+ // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64
+ // CHECK-LABEL: func @math_powf
+ func.func @math_powf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.powf %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = math.powf %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64
+ // CHECK-LABEL: func @math_rsqrt
+ func.func @math_rsqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.rsqrt %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.rsqrt %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+ // CHECK-LABEL: func @math_sin
+ func.func @math_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.sin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.sin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
+ // CHECK-LABEL: func @math_sqrt
+ func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.sqrt %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.sqrt %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_tanh
+ func.func @math_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.tanh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.tanh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
+ // CHECK-LABEL: func @math_tan
+ func.func @math_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.tan %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.tan %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64
+ // CHECK-LABEL: func @math_erf
+ func.func @math_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.erf %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.erf %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
+ // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
+ // CHECK-LABEL: func @arith_remf
+ func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = arith.remf %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = arith.remf %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
More information about the Mlir-commits
mailing list