[Mlir-commits] [mlir] 3fae555 - [MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass (#98653)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 17 06:33:07 PDT 2024


Author: Jan Leyonberg
Date: 2024-07-17T09:33:04-04:00
New Revision: 3fae5551de72756c3bb9fb2e5a29c95f02cbbd6b

URL: https://github.com/llvm/llvm-project/commit/3fae5551de72756c3bb9fb2e5a29c95f02cbbd6b
DIFF: https://github.com/llvm/llvm-project/commit/3fae5551de72756c3bb9fb2e5a29c95f02cbbd6b.diff

LOG: [MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass (#98653)

This patch refactors the conversion of math operations to ROCDL library
calls. This pass will also be used in flang to lower Fortran
intrinsics/math functions for OpenMP target offloading codegen.

Added: 
    mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
    mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
    mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
    mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Modified: 
    mlir/include/mlir/Conversion/Passes.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/CMakeLists.txt
    mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
new file mode 100644
index 0000000000000..fa7a635568c7c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -0,0 +1,26 @@
+//===- MathToROCDL.h - Utils to convert from the Math dialect -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to ROCDL calls.
+void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+                                           RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8c6f85d461aea..208f26489d6c3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -46,6 +46,7 @@
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..54b94bbfb93d1 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -733,6 +733,23 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
+  let summary = "Convert Math dialect to ROCDL library calls";
+  let description = [{
+    This pass converts supported Math ops to ROCDL library calls.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "func::FuncDialect",
+    "ROCDL::ROCDLDialect",
+    "vector::VectorDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToSPIRV
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e107738a4c50c..80c8b84d9ae89 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
+add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)

diff  --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 70707b5c3a049..945e3ccdfa87b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
   MLIRArithToLLVM
   MLIRArithTransforms
   MLIRMathToLLVM
+  MLIRMathToROCDL
   MLIRAMDGPUToROCDL
   MLIRFuncToLLVM
   MLIRGPUDialect

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 40eb15a491063..100181cdc69fe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
 
   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
 
-  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
-                                   "__ocml_fabs_f64");
-  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64");
-  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64");
-  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64");
-  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64");
-  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
-                                  "__ocml_exp_f64");
-  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64");
-  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64");
-  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64");
-  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
-                                    "__ocml_fmod_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
-                                  "__ocml_log_f64");
-  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
-                                    "__ocml_log10_f64");
-  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
-                                    "__ocml_log1p_f64");
-  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
-                                   "__ocml_log2_f64");
-  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
-                                   "__ocml_pow_f64");
-  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
-                                    "__ocml_rsqrt_f64");
-  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
-                                  "__ocml_sin_f64");
-  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
-                                   "__ocml_sqrt_f64");
-  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
-                                   "__ocml_tanh_f64");
-  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
-                                  "__ocml_tan_f64");
-  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
-                                  "__ocml_erf_f64");
+  populateMathToROCDLConversionPatterns(converter, patterns);
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>

diff  --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..2771955aa9493
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToROCDL
+  MathToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRMathDialect
+  MLIRLLVMCommonConversion
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )

diff  --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
new file mode 100644
index 0000000000000..03c7ce5dac0d1
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -0,0 +1,146 @@
+//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-rocdl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns, StringRef f32Func,
+                               StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
+void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+                                                 RewritePatternSet &patterns) {
+  // Handled by mathToLLVM: math::AbsIOp
+  // Handled by mathToLLVM: math::CopySignOp
+  // Handled by mathToLLVM: math::CountLeadingZerosOp
+  // Handled by mathToLLVM: math::CountTrailingZerosOp
+  // Handled by mathToLLVM: math::CtPopOp
+  // Handled by mathToLLVM: math::FmaOp
+  // FIXME: math::IPowIOp
+  // FIXME: math::FPowIOp
+  // Handled by mathToLLVM: math::RoundEvenOp
+  // Handled by mathToLLVM: math::RoundOp
+  // Handled by mathToLLVM: math::TruncOp
+  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
+                                   "__ocml_fabs_f64");
+  populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
+                                   "__ocml_acos_f64");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
+                                    "__ocml_acosh_f64");
+  populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
+                                   "__ocml_asin_f64");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
+                                    "__ocml_asinh_f64");
+  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
+                                   "__ocml_atan_f64");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
+                                    "__ocml_atanh_f64");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
+                                    "__ocml_atan2_f64");
+  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
+                                   "__ocml_cbrt_f64");
+  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
+                                   "__ocml_ceil_f64");
+  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
+                                  "__ocml_cos_f64");
+  populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
+                                   "__ocml_cosh_f64");
+  populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
+                                   "__ocml_sinh_f64");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
+                                  "__ocml_exp_f64");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
+                                   "__ocml_exp2_f64");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
+                                    "__ocml_expm1_f64");
+  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
+                                    "__ocml_floor_f64");
+  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
+                                  "__ocml_log_f64");
+  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
+                                    "__ocml_log10_f64");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
+                                    "__ocml_log1p_f64");
+  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
+                                   "__ocml_log2_f64");
+  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
+                                   "__ocml_pow_f64");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
+                                    "__ocml_rsqrt_f64");
+  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
+                                  "__ocml_sin_f64");
+  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
+                                   "__ocml_sqrt_f64");
+  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
+                                   "__ocml_tanh_f64");
+  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
+                                  "__ocml_tan_f64");
+  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
+                                  "__ocml_erf_f64");
+  // Single arith pattern that needs a ROCDL call, probably not
+  // worth creating a separate pass for it.
+  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
+                                    "__ocml_fmod_f64");
+}
+
+namespace {
+struct ConvertMathToROCDLPass
+    : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+  ConvertMathToROCDLPass() = default;
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToROCDLPass::runOnOperation() {
+  auto m = getOperation();
+  MLIRContext *ctx = m.getContext();
+
+  RewritePatternSet patterns(&getContext());
+  LowerToLLVMOptions options(ctx, DataLayout(m));
+  LLVMTypeConverter converter(ctx, options);
+  populateMathToROCDLConversionPatterns(converter, patterns);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         vector::VectorDialect, LLVM::LLVMDialect>();
+  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
+                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
+                      LLVM::SqrtOp>();
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}

diff  --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
new file mode 100644
index 0000000000000..a406ec45a7f10
--- /dev/null
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -0,0 +1,435 @@
+// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
+  // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
+  // CHECK-LABEL: func @arith_remf
+  func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = arith.remf %arg_f32, %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result64 = arith.remf %arg_f64, %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
+  // CHECK-LABEL: func @math_absf
+  func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.absf %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.absf %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
+  // CHECK-LABEL: func @math_acos
+  func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acos %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.acos %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_acosh
+  func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acosh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.acosh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
+  // CHECK-LABEL: func @math_asin
+  func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asin %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.asin %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_asinh
+  func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asinh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.asinh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
+  // CHECK-LABEL: func @math_atan
+  func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atan %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.atan %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_atanh
+  func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atanh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.atanh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32
+  // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64
+  // CHECK-LABEL: func @math_atan2
+  func.func @math_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atan2 %arg_f32, %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result64 = math.atan2 %arg_f64, %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
+  // CHECK-LABEL: func @math_cbrt
+  func.func @math_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.cbrt %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.cbrt %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
+  // CHECK-LABEL: func @math_ceil
+  func.func @math_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.ceil %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.ceil %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
+  // CHECK-LABEL: func @math_cos
+  func.func @math_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.cos %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.cos %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_cosh
+  func.func @math_cosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.cosh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.cosh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_sinh
+  func.func @math_sinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.sinh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.sinh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
+  // CHECK-LABEL: func @math_exp
+  func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.exp %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.exp %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
+  // CHECK-LABEL: func @math_exp2
+  func.func @math_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.exp2 %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.exp2 %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
+  // CHECK-LABEL: func @math_expm1
+  func.func @math_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.expm1 %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.expm1 %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
+  // CHECK-LABEL: func @math_floor
+  func.func @math_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.floor %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.floor %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
+  // CHECK-LABEL: func @math_log
+  func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.log %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.log %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
+  // CHECK-LABEL: func @math_log10
+  func.func @math_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.log10 %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.log10 %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
+  // CHECK-LABEL: func @math_log1p
+  func.func @math_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.log1p %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.log1p %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32
+  // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64
+  // CHECK-LABEL: func @math_powf
+  func.func @math_powf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.powf %arg_f32, %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result64 = math.powf %arg_f64, %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64
+  // CHECK-LABEL: func @math_rsqrt
+  func.func @math_rsqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.rsqrt %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.rsqrt %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+  // CHECK-LABEL: func @math_sin
+  func.func @math_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.sin %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.sin %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
+  // CHECK-LABEL: func @math_sqrt
+  func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.sqrt %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.sqrt %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_tanh
+  func.func @math_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.tanh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.tanh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
+  // CHECK-LABEL: func @math_tan
+  func.func @math_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.tan %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.tan %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64
+  // CHECK-LABEL: func @math_erf
+  func.func @math_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.erf %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.erf %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
+  // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
+  // CHECK-LABEL: func @arith_remf
+  func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = arith.remf %arg_f32, %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result64 = arith.remf %arg_f64, %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+


        


More information about the Mlir-commits mailing list