[Mlir-commits] [mlir] [mlir][AMDGPU] Add support for AMD f16 math library calls (PR #108809)
Daniel Hernandez-Juarez
llvmlistbot at llvm.org
Mon Sep 16 02:51:27 PDT 2024
https://github.com/dhernandez0 created https://github.com/llvm/llvm-project/pull/108809
In this PR we add support for AMD f16 math library calls (__ocml_*_f16)
>From 1812c25be457d7ead1dd024f62cf83d437696472 Mon Sep 17 00:00:00 2001
From: Daniel Hernandez-Juarez <dhernandez0 at gmail.com>
Date: Mon, 16 Sep 2024 09:42:42 +0000
Subject: [PATCH] Add support for AMD f16 math library calls
---
.../GPUCommon/OpToFuncCallLowering.h | 20 +-
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 4 +-
.../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 4 +-
.../Conversion/MathToROCDL/MathToROCDL.cpp | 61 +++---
.../Conversion/GPUToROCDL/gpu-to-rocdl.mlir | 163 ++++++++------
.../Conversion/MathToROCDL/math-to-rocdl.mlir | 200 ++++++++++++------
6 files changed, 284 insertions(+), 168 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 6be5548fdb60ef..8a9414d32ec611 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -17,10 +17,10 @@
namespace mlir {
/// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
-/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
+/// `f32ApproxFunc` or `f16Func` depending on the element type and the fastMathFlag of that
/// Op. The function declaration is added in case it was not added before.
///
-/// If the input values are of f16 type, the value is first casted to f32, the
+/// If the input values are of unsupported type, the value is first casted to f32, the
/// function called and then the result casted back.
///
/// Example with NVVM:
@@ -41,9 +41,9 @@ template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
- StringRef f64Func, StringRef f32ApproxFunc)
+ StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func)
: ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
- f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
+ f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -89,7 +89,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
private:
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
Type type = operand.getType();
- if (!isa<Float16Type>(type))
+ if (!isa<FloatType>(type))
+ return operand;
+
+ // if there's a f16 function, no need to cast f16 values
+ if (!f16Func.empty() && isa<Float16Type>(type))
+ return operand;
+
+ if (isa<Float64Type>(type) || isa<Float32Type>(type))
return operand;
return rewriter.create<LLVM::FPExtOp>(
@@ -102,6 +109,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
}
StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+ if (isa<Float16Type>(type))
+ return f16Func;
if (isa<Float32Type>(type)) {
if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
!f32ApproxFunc.empty())
@@ -130,6 +139,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
const std::string f32Func;
const std::string f64Func;
const std::string f32ApproxFunc;
+ const std::string f16Func;
};
} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 164622d77e6b62..f5650c35c3b3c4 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -336,10 +336,10 @@ template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func,
- StringRef f32ApproxFunc = "") {
+ StringRef f32ApproxFunc = "", StringRef f16Func = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
- f32ApproxFunc);
+ f32ApproxFunc, f16Func);
}
void mlir::populateGpuSubgroupReduceOpLoweringPattern(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index fc3e1fc4f9d0c9..6b9e6b1192e050 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -346,9 +346,9 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
- StringRef f64Func) {
+ StringRef f64Func, StringRef f32ApproxFunc, StringRef f16Func) {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
- patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f32ApproxFunc, f16Func);
}
void mlir::populateGpuToROCDLConversionPatterns(
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index b3b4d81e7ffa5b..1611a8835c91ef 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -39,16 +39,17 @@ template <typename OpTy>
static void populateOpPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns, StringRef f32Func,
StringRef f64Func,
+ StringRef f16Func,
StringRef f32ApproxFunc = "") {
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
- f32ApproxFunc);
+ f32ApproxFunc, f16Func);
}
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
- // Handled by mathToLLVM: math::AbsFIOp
+ // Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
// Handled by mathToLLVM: math::CountLeadingZerosOp
// Handled by mathToLLVM: math::CountTrailingZerosOp
@@ -63,59 +64,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
// Handled by mathToLLVM: math::SqrtOp
// Handled by mathToLLVM: math::TruncOp
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
- "__ocml_acos_f64");
+ "__ocml_acos_f64", "__ocml_acos_f16");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
- "__ocml_acosh_f64");
+ "__ocml_acosh_f64", "__ocml_acosh_f16");
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
- "__ocml_asin_f64");
+ "__ocml_asin_f64", "__ocml_asin_f16");
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
- "__ocml_asinh_f64");
+ "__ocml_asinh_f64", "__ocml_asinh_f16");
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
- "__ocml_atan_f64");
+ "__ocml_atan_f64", "__ocml_atan_f16");
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
- "__ocml_atanh_f64");
+ "__ocml_atanh_f64", "__ocml_atanh_f16");
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
- "__ocml_atan2_f64");
+ "__ocml_atan2_f64", "__ocml_atan2_f16");
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
- "__ocml_cbrt_f64");
+ "__ocml_cbrt_f64", "__ocml_cbrt_f16");
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
- "__ocml_ceil_f64");
+ "__ocml_ceil_f64", "__ocml_ceil_f16");
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
- "__ocml_cos_f64");
+ "__ocml_cos_f64", "__ocml_cos_f16");
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
- "__ocml_cosh_f64");
+ "__ocml_cosh_f64", "__ocml_cosh_f16");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
- "__ocml_sinh_f64");
- populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");
+ "__ocml_sinh_f64", "__ocml_sinh_f16");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "",
+ "__ocml_exp_f64", "__ocml_exp_f16");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
- "__ocml_exp2_f64");
+ "__ocml_exp2_f64", "__ocml_exp2_f16");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
- "__ocml_expm1_f64");
+ "__ocml_expm1_f64", "__ocml_expm1_f16");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
- "__ocml_floor_f64");
- populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");
+ "__ocml_floor_f64", "__ocml_floor_f16");
+ populateOpPatterns<math::LogOp>(converter, patterns, "",
+ "__ocml_log_f64", "__ocml_log_f16");
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
- "__ocml_log10_f64");
+ "__ocml_log10_f64", "__ocml_log10_f16");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
- "__ocml_log1p_f64");
+ "__ocml_log1p_f64", "__ocml_log1p_f16");
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
- "__ocml_log2_f64");
+ "__ocml_log2_f64", "__ocml_log2_f16");
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
- "__ocml_pow_f64");
+ "__ocml_pow_f64", "__ocml_pow_f16");
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
- "__ocml_rsqrt_f64");
+ "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
- "__ocml_sin_f64");
+ "__ocml_sin_f64", "__ocml_sin_f16");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
- "__ocml_tanh_f64");
+ "__ocml_tanh_f64", "__ocml_tanh_f16");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
- "__ocml_tan_f64");
+ "__ocml_tan_f64", "__ocml_tan_f16");
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
- "__ocml_erf_f64");
+ "__ocml_erf_f64", "__ocml_erf_f16");
// Single arith pattern that needs a ROCDL call, probably not
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
- "__ocml_fmod_f64");
+ "__ocml_fmod_f64", "__ocml_fmod_f16");
}
namespace {
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index eb065cbab86789..0d3e9f4ea2bf39 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -162,11 +162,12 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
// CHECK-LABEL: func @gpu_exp
func.func @gpu_exp(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
%result16 = math.exp %arg_f16 : f16
- // CHECK: llvm.intr.exp(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
%result32 = math.exp %arg_f32 : f32
// CHECK: llvm.intr.exp(%{{.*}}) : (f32) -> f32
%result64 = math.exp %arg_f64 : f64
@@ -178,11 +179,12 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
// CHECK-LABEL: func @gpu_log
func.func @gpu_log(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
%result16 = math.log %arg_f16 : f16
- // CHECK: llvm.intr.log(%{{.*}}) : (f16) -> f16
+ // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
%result32 = math.log %arg_f32 : f32
// CHECK: llvm.intr.log(%{{.*}}) : (f32) -> f32
%result64 = math.log %arg_f64 : f64
@@ -194,108 +196,113 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
// CHECK-LABEL: func @gpu_cbrt
- func.func @gpu_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cbrt %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
%result32 = math.cbrt %arg_f32 : f32
// CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
%result64 = math.cbrt %arg_f64 : f64
// CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16
// CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
// CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
// CHECK-LABEL: func @gpu_ceil
- func.func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.ceil %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
%result32 = math.ceil %arg_f32 : f32
// CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
%result64 = math.ceil %arg_f64 : f64
// CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16
// CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
// CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
// CHECK-LABEL: func @gpu_floor
- func.func @gpu_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.floor %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
%result32 = math.floor %arg_f32 : f32
// CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
%result64 = math.floor %arg_f64 : f64
// CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16
// CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
// CHECK-LABEL: func @gpu_cos
- func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cos %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
%result32 = math.cos %arg_f32 : f32
// CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
%result64 = math.cos %arg_f64 : f64
// CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f64 : f64) -> (f64) {
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result64 : f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_exp2_f16(f16) -> f16
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
// CHECK-LABEL: func @gpu_exp2
- func.func @gpu_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.exp2 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
%exp2_f32 = math.exp2 %arg_f32 : f32
// CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
%result32 = math.exp2 %exp2_f32 : f32
// CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
%result64 = math.exp2 %arg_f64 : f64
// CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
+
// Test that we handled properly operation with SymbolTable other than module op
gpu.module @test_module {
"test.symbol_scope"() ({
// CHECK: test.symbol_scope
+ // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
// CHECK-LABEL: func @gpu_sin
- func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %sin_f32 = math.sin %arg_f32 : f32
+ func.func @gpu_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+ %result16 = math.sin %arg_f16 : f16
// CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- %result32 = math.sin %sin_f32 : f32
- // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sin %arg_f64 : f64
+ %result32 = math.sin %arg_f32 : f32
// CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ %result64 = math.sin %arg_f64 : f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
"test.finish" () : () -> ()
}) : () -> ()
@@ -304,89 +311,102 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_expm1_f16(f16) -> f16
// CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
// CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
// CHECK-LABEL: func @gpu_expm1
- func.func @gpu_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.expm1 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
%expm1_f32 = math.expm1 %arg_f32 : f32
// CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
%result32 = math.expm1 %expm1_f32 : f32
// CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
%result64 = math.expm1 %arg_f64 : f64
// CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
// CHECK-LABEL: func @gpu_log
- func.func @gpu_log(%arg_f64 : f64) -> (f64) {
+ func.func @gpu_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+ %result16 = math.log %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
%result64 = math.log %arg_f64 : f64
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result64 : f64
+ func.return %result16, %result64 : f16, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_log1p_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
// CHECK-LABEL: func @gpu_log1p
- func.func @gpu_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.log1p %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16
%result32 = math.log1p %arg_f32 : f32
// CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
%result64 = math.log1p %arg_f64 : f64
// CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_log10_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
// CHECK-LABEL: func @gpu_log10
- func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.log10 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16
%result32 = math.log10 %arg_f32 : f32
// CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
%result64 = math.log10 %arg_f64 : f64
// CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_log2_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log2_f64(f64) -> f64
// CHECK-LABEL: func @gpu_log2
- func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_log2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.log2 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log2_f16(%{{.*}}) : (f16) -> f16
%result32 = math.log2 %arg_f32 : f32
// CHECK: llvm.call @__ocml_log2_f32(%{{.*}}) : (f32) -> f32
%result64 = math.log2 %arg_f64 : f64
// CHECK: llvm.call @__ocml_log2_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_rsqrt_f16(f16) -> f16
// CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64
// CHECK-LABEL: func @gpu_rsqrt
- func.func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
- -> (f16, f32, f64) {
+ func.func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
%result16 = math.rsqrt %arg_f16 : f16
- // CHECK: llvm.fpext %{{.*}} : f16 to f32
- // CHECK-NEXT: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
- // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
+ // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16
%result32 = math.rsqrt %arg_f32 : f32
// CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
%result64 = math.rsqrt %arg_f64 : f64
@@ -398,90 +418,108 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_tan_f16(f16) -> f16
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
// CHECK-LABEL: func @gpu_tan
- func.func @gpu_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.tan %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16
%result32 = math.tan %arg_f32 : f32
// CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
%result64 = math.tan %arg_f64 : f64
// CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_tanh_f16(f16) -> f16
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
// CHECK-LABEL: func @gpu_tanh
- func.func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.tanh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16
%result32 = math.tanh %arg_f32 : f32
// CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
%result64 = math.tanh %arg_f64 : f64
// CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_atan_f16(f16) -> f16
// CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
// CHECK-LABEL: func @gpu_atan
- func.func @gpu_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.atan %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16
%result32 = math.atan %arg_f32 : f32
// CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
%result64 = math.atan %arg_f64 : f64
// CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_atan2_f16(f16, f16) -> f16
// CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64
// CHECK-LABEL: func @gpu_atan2
- func.func @gpu_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.atan2 %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}) : (f16, f16) -> f16
%result32 = math.atan2 %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}) : (f32, f32) -> f32
%result64 = math.atan2 %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_pow_f16(f16, f16) -> f16
// CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64
// CHECK-LABEL: func @gpu_pow
- func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_pow(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.powf %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
%result32 = math.powf %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
%result64 = math.powf %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_erf_f16(f16) -> f16
// CHECK: llvm.func @__ocml_erf_f32(f32) -> f32
// CHECK: llvm.func @__ocml_erf_f64(f64) -> f64
// CHECK-LABEL: func @gpu_erf
- func.func @gpu_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.erf %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16
%result32 = math.erf %arg_f32 : f32
// CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
%result64 = math.erf %arg_f64 : f64
// CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
@@ -543,9 +581,9 @@ gpu.module @test_module {
// -----
gpu.module @module {
-// CHECK-LABEL: @spirv_exp
+// CHECK-LABEL: @spirv_sin
// CHECK: llvm.call @__ocml_sin_f32
- spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
+ spirv.func @spirv_sin(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
%0 = math.sin %arg0 : vector<4xf32>
spirv.ReturnValue %0 : vector<4xf32>
}
@@ -602,15 +640,18 @@ gpu.module @test_module {
// -----
gpu.module @test_module {
+ // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
// CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
// CHECK-LABEL: func @gpu_fmod
- func.func @gpu_fmod(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @gpu_fmod(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = arith.remf %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
%result32 = arith.remf %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
%result64 = arith.remf %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 19d89e03a7f483..7ba049533fc161 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -1,399 +1,461 @@
// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s
module @test_module {
+ // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
// CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
// CHECK-LABEL: func @arith_remf
- func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @arith_remf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = arith.remf %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
%result32 = arith.remf %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
%result64 = arith.remf %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_acos_f16(f16) -> f16
// CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
// CHECK-LABEL: func @math_acos
- func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_acos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.acos %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16
%result32 = math.acos %arg_f32 : f32
// CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
%result64 = math.acos %arg_f64 : f64
// CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_acosh_f16(f16) -> f16
// CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
// CHECK-LABEL: func @math_acosh
- func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_acosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.acosh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16
%result32 = math.acosh %arg_f32 : f32
// CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
%result64 = math.acosh %arg_f64 : f64
// CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_asin_f16(f16) -> f16
// CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
// CHECK-LABEL: func @math_asin
- func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_asin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.asin %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16
%result32 = math.asin %arg_f32 : f32
// CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
%result64 = math.asin %arg_f64 : f64
// CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_asinh_f16(f16) -> f16
// CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
// CHECK-LABEL: func @math_asinh
- func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_asinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.asinh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16
%result32 = math.asinh %arg_f32 : f32
// CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
%result64 = math.asinh %arg_f64 : f64
// CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_atan_f16(f16) -> f16
// CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
// CHECK-LABEL: func @math_atan
- func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.atan %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16
%result32 = math.atan %arg_f32 : f32
// CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
%result64 = math.atan %arg_f64 : f64
// CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_atanh_f16(f16) -> f16
// CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
// CHECK-LABEL: func @math_atanh
- func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.atanh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16
%result32 = math.atanh %arg_f32 : f32
// CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32
%result64 = math.atanh %arg_f64 : f64
// CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_atan2_f16(f16, f16) -> f16
// CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64
// CHECK-LABEL: func @math_atan2
- func.func @math_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.atan2 %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
%result32 = math.atan2 %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
%result64 = math.atan2 %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
// CHECK-LABEL: func @math_cbrt
- func.func @math_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cbrt %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
%result32 = math.cbrt %arg_f32 : f32
// CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
%result64 = math.cbrt %arg_f64 : f64
// CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16
// CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
// CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
// CHECK-LABEL: func @math_ceil
- func.func @math_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.ceil %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
%result32 = math.ceil %arg_f32 : f32
// CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
%result64 = math.ceil %arg_f64 : f64
// CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16
// CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
// CHECK-LABEL: func @math_cos
- func.func @math_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cos %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
%result32 = math.cos %arg_f32 : f32
// CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
%result64 = math.cos %arg_f64 : f64
// CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_cosh_f16(f16) -> f16
// CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64
// CHECK-LABEL: func @math_cosh
- func.func @math_cosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_cosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.cosh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16
%result32 = math.cosh %arg_f32 : f32
// CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32
%result64 = math.cosh %arg_f64 : f64
// CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_sinh_f16(f16) -> f16
// CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64
// CHECK-LABEL: func @math_sinh
- func.func @math_sinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_sinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.sinh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16
%result32 = math.sinh %arg_f32 : f32
// CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32
%result64 = math.sinh %arg_f64 : f64
// CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
// CHECK-LABEL: func @math_exp
- func.func @math_exp(%arg_f64 : f64) -> (f64) {
+ func.func @math_exp(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+ %result16 = math.exp %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result64 : f64
+ func.return %result16, %result64 : f16, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_exp2_f16(f16) -> f16
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
// CHECK-LABEL: func @math_exp2
- func.func @math_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.exp2 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
%result32 = math.exp2 %arg_f32 : f32
// CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
%result64 = math.exp2 %arg_f64 : f64
// CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_expm1_f16(f16) -> f16
// CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
// CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
// CHECK-LABEL: func @math_expm1
- func.func @math_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.expm1 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
%result32 = math.expm1 %arg_f32 : f32
// CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
%result64 = math.expm1 %arg_f64 : f64
// CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16
// CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
// CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
// CHECK-LABEL: func @math_floor
- func.func @math_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.floor %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
%result32 = math.floor %arg_f32 : f32
// CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
%result64 = math.floor %arg_f64 : f64
// CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
// CHECK-LABEL: func @math_log
- func.func @math_log(%arg_f64 : f64) -> (f64) {
+ func.func @math_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+ %result16 = math.log %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
%result64 = math.log %arg_f64 : f64
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result64 : f64
+ func.return %result16, %result64 : f16, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_log10_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
// CHECK-LABEL: func @math_log10
- func.func @math_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.log10 %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16
%result32 = math.log10 %arg_f32 : f32
// CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
%result64 = math.log10 %arg_f64 : f64
// CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_log1p_f16(f16) -> f16
// CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
// CHECK-LABEL: func @math_log1p
- func.func @math_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.log1p %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16
%result32 = math.log1p %arg_f32 : f32
// CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
%result64 = math.log1p %arg_f64 : f64
// CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_pow_f16(f16, f16) -> f16
// CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32
// CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64
// CHECK-LABEL: func @math_powf
- func.func @math_powf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_powf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.powf %arg_f16, %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
%result32 = math.powf %arg_f32, %arg_f32 : f32
// CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
%result64 = math.powf %arg_f64, %arg_f64 : f64
// CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_rsqrt_f16(f16) -> f16
// CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64
// CHECK-LABEL: func @math_rsqrt
- func.func @math_rsqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.rsqrt %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16
%result32 = math.rsqrt %arg_f32 : f32
// CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
%result64 = math.rsqrt %arg_f64 : f64
// CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
// CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
// CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
// CHECK-LABEL: func @math_sin
- func.func @math_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.sin %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
%result32 = math.sin %arg_f32 : f32
// CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
%result64 = math.sin %arg_f64 : f64
// CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_tanh_f16(f16) -> f16
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
// CHECK-LABEL: func @math_tanh
- func.func @math_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.tanh %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16
%result32 = math.tanh %arg_f32 : f32
// CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
%result64 = math.tanh %arg_f64 : f64
// CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_tan_f16(f16) -> f16
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
// CHECK-LABEL: func @math_tan
- func.func @math_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.tan %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16
%result32 = math.tan %arg_f32 : f32
// CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
%result64 = math.tan %arg_f64 : f64
// CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
}
// -----
module @test_module {
+ // CHECK: llvm.func @__ocml_erf_f16(f16) -> f16
// CHECK: llvm.func @__ocml_erf_f32(f32) -> f32
// CHECK: llvm.func @__ocml_erf_f64(f64) -> f64
// CHECK-LABEL: func @math_erf
- func.func @math_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ func.func @math_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+ %result16 = math.erf %arg_f16 : f16
+ // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16
%result32 = math.erf %arg_f32 : f32
// CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
%result64 = math.erf %arg_f64 : f64
// CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result16, %result32, %result64 : f16, f32, f64
}
-}
-
-// -----
-
-module @test_module {
- // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
- // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
- // CHECK-LABEL: func @arith_remf
- func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = arith.remf %arg_f32, %arg_f32 : f32
- // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
- %result64 = arith.remf %arg_f64, %arg_f64 : f64
- // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
+}
\ No newline at end of file
More information about the Mlir-commits
mailing list