[Mlir-commits] [mlir] [mlir][AMDGPU] Add a scheduling barrier guard around inlineAsm lds.barrier (PR #109678)

Daniel Hernandez-Juarez llvmlistbot at llvm.org
Mon Sep 23 08:42:36 PDT 2024


https://github.com/dhernandez0 created https://github.com/llvm/llvm-project/pull/109678

This commit adds a scheduling region around the inlineAsm to guard against possible complications arising from them interfering with the backend scheduler / register allocation.

CC: @manupak @krzysz00 

>From 7d0fc51c28141370f673c43aee49898d6a6594e5 Mon Sep 17 00:00:00 2001
From: Daniel Hernandez-Juarez <dhernandez0 at gmail.com>
Date: Mon, 16 Sep 2024 09:42:42 +0000
Subject: [PATCH 1/3] Add support for AMD f16 math library calls

---
 .../GPUCommon/OpToFuncCallLowering.h          |  24 +-
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |   6 +-
 .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp      |  11 +-
 .../Conversion/MathToROCDL/MathToROCDL.cpp    |  62 +++---
 .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir   | 163 +++++++++-----
 .../Conversion/MathToROCDL/math-to-rocdl.mlir | 210 ++++++++++++------
 6 files changed, 307 insertions(+), 169 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 6be5548fdb60ef..8ff4d4ec67b9fd 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -17,11 +17,13 @@
 namespace mlir {
 
 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` or
-/// `f32ApproxFunc` depending on the element type and the fastMathFlag of that
-/// Op. The function declaration is added in case it was not added before.
+/// `f32ApproxFunc` or `f16Func` depending on the element type and the
+/// fastMathFlag of that Op. The function declaration is added in case it was
+/// not added before.
 ///
-/// If the input values are of f16 type, the value is first casted to f32, the
-/// function called and then the result casted back.
+/// If the input values are of bf16 type (or f16 type if f16Func is empty), the
+/// value is first casted to f32, the function called and then the result casted
+/// back.
 ///
 /// Example with NVVM:
 ///   %exp_f32 = math.exp %arg_f32 : f32
@@ -41,9 +43,10 @@ template <typename SourceOp>
 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func,
-                                StringRef f64Func, StringRef f32ApproxFunc)
+                                StringRef f64Func, StringRef f32ApproxFunc,
+                                StringRef f16Func)
       : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func),
-        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc) {}
+        f64Func(f64Func), f32ApproxFunc(f32ApproxFunc), f16Func(f16Func) {}
 
   LogicalResult
   matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
@@ -89,7 +92,11 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 private:
   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
     Type type = operand.getType();
-    if (!isa<Float16Type>(type))
+    if (!isa<Float16Type, BFloat16Type>(type))
+      return operand;
+
+    // if there's a f16 function, no need to cast f16 values
+    if (!f16Func.empty() && isa<Float16Type>(type))
       return operand;
 
     return rewriter.create<LLVM::FPExtOp>(
@@ -102,6 +109,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   }
 
   StringRef getFunctionName(Type type, arith::FastMathFlags flag) const {
+    if (isa<Float16Type>(type))
+      return f16Func;
     if (isa<Float32Type>(type)) {
       if (((uint32_t)arith::FastMathFlags::afn & (uint32_t)flag) &&
           !f32ApproxFunc.empty())
@@ -130,6 +139,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   const std::string f32Func;
   const std::string f64Func;
   const std::string f32ApproxFunc;
+  const std::string f16Func;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 4be330b0bb26bb..2b91a6c28c05e8 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -335,11 +335,11 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func,
-                               StringRef f32ApproxFunc = "") {
+                               StringRef f64Func, StringRef f32ApproxFunc = "",
+                               StringRef f16Func = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc);
+                                           f32ApproxFunc, f16Func);
 }
 
 void mlir::populateGpuSubgroupReduceOpLoweringPattern(
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index fc3e1fc4f9d0c9..482c9e2c2d0017 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -334,10 +334,9 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
   target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
                       LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
                       LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
-  // These ops are legal for f16 and f32 type.
+  // These ops are legal for f32 type.
   target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
-    return any_of(op->getOperandTypes(),
-                  llvm::IsaPred<Float16Type, Float32Type>);
+    return any_of(op->getOperandTypes(), llvm::IsaPred<Float32Type>);
   });
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
@@ -346,9 +345,11 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func) {
+                               StringRef f64Func, StringRef f32ApproxFunc,
+                               StringRef f16Func) {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
-  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f32ApproxFunc,
+                                           f16Func);
 }
 
 void mlir::populateGpuToROCDLConversionPatterns(
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index b3b4d81e7ffa5b..8330713ea66e5c 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -38,17 +38,17 @@ using namespace mlir;
 template <typename OpTy>
 static void populateOpPatterns(LLVMTypeConverter &converter,
                                RewritePatternSet &patterns, StringRef f32Func,
-                               StringRef f64Func,
+                               StringRef f64Func, StringRef f16Func,
                                StringRef f32ApproxFunc = "") {
   patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func,
-                                           f32ApproxFunc);
+                                           f32ApproxFunc, f16Func);
 }
 
 void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   // Handled by mathToLLVM: math::AbsIOp
-  // Handled by mathToLLVM: math::AbsFIOp
+  // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
   // Handled by mathToLLVM: math::CountLeadingZerosOp
   // Handled by mathToLLVM: math::CountTrailingZerosOp
@@ -63,59 +63,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
   // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
   populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
-                                   "__ocml_acos_f64");
+                                   "__ocml_acos_f64", "__ocml_acos_f16");
   populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
-                                    "__ocml_acosh_f64");
+                                    "__ocml_acosh_f64", "__ocml_acosh_f16");
   populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
-                                   "__ocml_asin_f64");
+                                   "__ocml_asin_f64", "__ocml_asin_f16");
   populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
-                                    "__ocml_asinh_f64");
+                                    "__ocml_asinh_f64", "__ocml_asinh_f16");
   populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64");
+                                   "__ocml_atan_f64", "__ocml_atan_f16");
   populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
-                                    "__ocml_atanh_f64");
+                                    "__ocml_atanh_f64", "__ocml_atanh_f16");
   populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64");
+                                    "__ocml_atan2_f64", "__ocml_atan2_f16");
   populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64");
+                                   "__ocml_cbrt_f64", "__ocml_cbrt_f16");
   populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64");
+                                   "__ocml_ceil_f64", "__ocml_ceil_f16");
   populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64");
+                                  "__ocml_cos_f64", "__ocml_cos_f16");
   populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
-                                   "__ocml_cosh_f64");
+                                   "__ocml_cosh_f64", "__ocml_cosh_f16");
   populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
-                                   "__ocml_sinh_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");
+                                   "__ocml_sinh_f64", "__ocml_sinh_f16");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
+                                  "__ocml_exp_f16");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64");
+                                   "__ocml_exp2_f64", "__ocml_exp2_f16");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64");
+                                    "__ocml_expm1_f64", "__ocml_expm1_f16");
   populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");
+                                    "__ocml_floor_f64", "__ocml_floor_f16");
+  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
+                                  "__ocml_log_f16");
   populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
-                                    "__ocml_log10_f64");
+                                    "__ocml_log10_f64", "__ocml_log10_f16");
   populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
-                                    "__ocml_log1p_f64");
+                                    "__ocml_log1p_f64", "__ocml_log1p_f16");
   populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
-                                   "__ocml_log2_f64");
+                                   "__ocml_log2_f64", "__ocml_log2_f16");
   populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
-                                   "__ocml_pow_f64");
+                                   "__ocml_pow_f64", "__ocml_pow_f16");
   populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
-                                    "__ocml_rsqrt_f64");
+                                    "__ocml_rsqrt_f64", "__ocml_rsqrt_f16");
   populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
-                                  "__ocml_sin_f64");
+                                  "__ocml_sin_f64", "__ocml_sin_f16");
   populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
-                                   "__ocml_tanh_f64");
+                                   "__ocml_tanh_f64", "__ocml_tanh_f16");
   populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
-                                  "__ocml_tan_f64");
+                                  "__ocml_tan_f64", "__ocml_tan_f16");
   populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
-                                  "__ocml_erf_f64");
+                                  "__ocml_erf_f64", "__ocml_erf_f16");
   // Single arith pattern that needs a ROCDL call, probably not
   // worth creating a separate pass for it.
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
-                                    "__ocml_fmod_f64");
+                                    "__ocml_fmod_f64", "__ocml_fmod_f16");
 }
 
 namespace {
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index eb065cbab86789..0d3e9f4ea2bf39 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -162,11 +162,12 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_exp
   func.func @gpu_exp(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
     %result16 = math.exp %arg_f16 : f16
-    // CHECK: llvm.intr.exp(%{{.*}})  : (f16) -> f16
+    // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.exp %arg_f32 : f32
     // CHECK: llvm.intr.exp(%{{.*}})  : (f32) -> f32
     %result64 = math.exp %arg_f64 : f64
@@ -178,11 +179,12 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_log
   func.func @gpu_log(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
     %result16 = math.log %arg_f16 : f16
-    // CHECK: llvm.intr.log(%{{.*}})  : (f16) -> f16
+    // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.log %arg_f32 : f32
     // CHECK: llvm.intr.log(%{{.*}})  : (f32) -> f32
     %result64 = math.log %arg_f64 : f64
@@ -194,108 +196,113 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_cbrt
-  func.func @gpu_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cbrt %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.cbrt %arg_f32 : f32
     // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.cbrt %arg_f64 : f64
     // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_ceil
-  func.func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.ceil %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.ceil %arg_f32 : f32
     // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.ceil %arg_f64 : f64
     // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_floor
-  func.func @gpu_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.floor %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.floor %arg_f32 : f32
     // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.floor %arg_f64 : f64
     // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_cos
-  func.func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cos %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.cos %arg_f32 : f32
     // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.cos %arg_f64 : f64
     // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
-  }
-}
-
-// -----
-
-gpu.module @test_module {
-  // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
-  // CHECK-LABEL: func @gpu_exp
-  func.func @gpu_exp(%arg_f64 : f64) -> (f64) {
-    %result64 = math.exp %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result64 : f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_exp2_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_exp2
-  func.func @gpu_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.exp2 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
     %exp2_f32 = math.exp2 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
     %result32 = math.exp2 %exp2_f32 : f32
     // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.exp2 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
+
 // Test that we handled properly operation with SymbolTable other than module op
 gpu.module @test_module {
   "test.symbol_scope"() ({
     // CHECK: test.symbol_scope
+    // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
     // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
     // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
     // CHECK-LABEL: func @gpu_sin
-    func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-      %sin_f32 = math.sin %arg_f32 : f32
+    func.func @gpu_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+      // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+      %result16 = math.sin %arg_f16 : f16
       // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
-      %result32 = math.sin %sin_f32 : f32
-      // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
-      %result64 = math.sin %arg_f64 : f64
+      %result32 = math.sin %arg_f32 : f32
       // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
-      func.return %result32, %result64 : f32, f64
+      %result64 = math.sin %arg_f64 : f64
+      func.return %result16, %result32, %result64 : f16, f32, f64
     }
     "test.finish" () : () -> ()
   }) : () -> ()
@@ -304,89 +311,102 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_expm1_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_expm1
-  func.func @gpu_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.expm1 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
     %expm1_f32 = math.expm1 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
     %result32 = math.expm1 %expm1_f32 : f32
     // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.expm1 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_log
-  func.func @gpu_log(%arg_f64 : f64) -> (f64) {
+  func.func @gpu_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+    %result16 = math.log %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
     %result64 = math.log %arg_f64 : f64
     // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
-    func.return %result64 : f64
+    func.return %result16, %result64 : f16, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_log1p_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_log1p
-  func.func @gpu_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.log1p %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.log1p %arg_f32 : f32
     // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.log1p %arg_f64 : f64
     // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_log10_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_log10
-  func.func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.log10 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.log10 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.log10 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_log2_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log2_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_log2
-  func.func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_log2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.log2 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log2_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.log2 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_log2_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.log2 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_log2_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_rsqrt_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_rsqrt
-  func.func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
-      -> (f16, f32, f64) {
+  func.func @gpu_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
     %result16 = math.rsqrt %arg_f16 : f16
-    // CHECK: llvm.fpext %{{.*}} : f16 to f32
-    // CHECK-NEXT: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
-    // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
+    // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.rsqrt %arg_f32 : f32
     // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.rsqrt %arg_f64 : f64
@@ -398,90 +418,108 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_tan_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_tan
-  func.func @gpu_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.tan %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.tan %arg_f32 : f32
     // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.tan %arg_f64 : f64
     // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_tanh_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_tanh
-  func.func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.tanh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.tanh %arg_f32 : f32
     // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.tanh %arg_f64 : f64
     // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_atan_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_atan
-  func.func @gpu_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.atan %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.atan %arg_f32 : f32
     // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.atan %arg_f64 : f64
     // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_atan2_f16(f16, f16) -> f16
   // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32
   // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64
   // CHECK-LABEL: func @gpu_atan2
-  func.func @gpu_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.atan2 %arg_f16, %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}) : (f16, f16) -> f16
     %result32 = math.atan2 %arg_f32, %arg_f32 : f32
     // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}) : (f32, f32) -> f32
     %result64 = math.atan2 %arg_f64, %arg_f64 : f64
     // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}) : (f64, f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_pow_f16(f16, f16) -> f16
   // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32
   // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64
   // CHECK-LABEL: func @gpu_pow
-  func.func @gpu_pow(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_pow(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.powf %arg_f16, %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
     %result32 = math.powf %arg_f32, %arg_f32 : f32
     // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
     %result64 = math.powf %arg_f64, %arg_f64 : f64
     // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_erf_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64
   // CHECK-LABEL: func @gpu_erf
-  func.func @gpu_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.erf %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.erf %arg_f32 : f32
     // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.erf %arg_f64 : f64
     // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
@@ -543,9 +581,9 @@ gpu.module @test_module {
 // -----
 
 gpu.module @module {
-// CHECK-LABEL: @spirv_exp
+// CHECK-LABEL: @spirv_sin
 // CHECK: llvm.call @__ocml_sin_f32
-  spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
+  spirv.func @spirv_sin(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
     %0 = math.sin %arg0 : vector<4xf32>
     spirv.ReturnValue %0 : vector<4xf32>
   }
@@ -602,15 +640,18 @@ gpu.module @test_module {
 // -----
 
 gpu.module @test_module {
+  // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
   // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
   // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
   // CHECK-LABEL: func @gpu_fmod
-  func.func @gpu_fmod(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @gpu_fmod(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = arith.remf %arg_f16, %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
     %result32 = arith.remf %arg_f32, %arg_f32 : f32
     // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
     %result64 = arith.remf %arg_f64, %arg_f64 : f64
     // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 19d89e03a7f483..7a56f44b137463 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -1,399 +1,483 @@
 // RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
   // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
   // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
   // CHECK-LABEL: func @arith_remf
-  func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @arith_remf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = arith.remf %arg_f16, %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_fmod_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
     %result32 = arith.remf %arg_f32, %arg_f32 : f32
     // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
     %result64 = arith.remf %arg_f64, %arg_f64 : f64
     // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_acos_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
   // CHECK-LABEL: func @math_acos
-  func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_acos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.acos %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_acos_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.acos %arg_f32 : f32
     // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.acos %arg_f64 : f64
     // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_acosh_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
   // CHECK-LABEL: func @math_acosh
-  func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_acosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.acosh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_acosh_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.acosh %arg_f32 : f32
     // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.acosh %arg_f64 : f64
     // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_asin_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
   // CHECK-LABEL: func @math_asin
-  func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_asin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.asin %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_asin_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.asin %arg_f32 : f32
     // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.asin %arg_f64 : f64
     // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_asinh_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
   // CHECK-LABEL: func @math_asinh
-  func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_asinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.asinh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_asinh_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.asinh %arg_f32 : f32
     // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.asinh %arg_f64 : f64
     // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_atan_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
   // CHECK-LABEL: func @math_atan
-  func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_atan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.atan %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_atan_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.atan %arg_f32 : f32
     // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.atan %arg_f64 : f64
     // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_atanh_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
   // CHECK-LABEL: func @math_atanh
-  func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_atanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.atanh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_atanh_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.atanh %arg_f32 : f32
     // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.atanh %arg_f64 : f64
     // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_atan2_f16(f16, f16) -> f16
   // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32
   // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64
   // CHECK-LABEL: func @math_atan2
-  func.func @math_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_atan2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.atan2 %arg_f16, %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_atan2_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
     %result32 = math.atan2 %arg_f32, %arg_f32 : f32
     // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
     %result64 = math.atan2 %arg_f64, %arg_f64 : f64
     // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_cbrt_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
   // CHECK-LABEL: func @math_cbrt
-  func.func @math_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_cbrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cbrt %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cbrt_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.cbrt %arg_f32 : f32
     // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.cbrt %arg_f64 : f64
     // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_ceil_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64
   // CHECK-LABEL: func @math_ceil
-  func.func @math_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_ceil(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.ceil %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_ceil_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.ceil %arg_f32 : f32
     // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.ceil %arg_f64 : f64
     // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_cos_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64
   // CHECK-LABEL: func @math_cos
-  func.func @math_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_cos(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cos %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cos_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.cos %arg_f32 : f32
     // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.cos %arg_f64 : f64
     // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_cosh_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64
   // CHECK-LABEL: func @math_cosh
-  func.func @math_cosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_cosh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.cosh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_cosh_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.cosh %arg_f32 : f32
     // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.cosh %arg_f64 : f64
     // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_sinh_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64
   // CHECK-LABEL: func @math_sinh
-  func.func @math_sinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_sinh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.sinh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_sinh_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.sinh %arg_f32 : f32
     // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.sinh %arg_f64 : f64
     // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_exp_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
   // CHECK-LABEL: func @math_exp
-  func.func @math_exp(%arg_f64 : f64) -> (f64) {
+  func.func @math_exp(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+    %result16 = math.exp %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_exp_f16(%{{.*}}) : (f16) -> f16
     %result64 = math.exp %arg_f64 : f64
     // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
-    func.return %result64 : f64
+    func.return %result16, %result64 : f16, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_exp2_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
   // CHECK-LABEL: func @math_exp2
-  func.func @math_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_exp2(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.exp2 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_exp2_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.exp2 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.exp2 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_expm1_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64
   // CHECK-LABEL: func @math_expm1
-  func.func @math_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_expm1(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.expm1 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_expm1_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.expm1 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.expm1 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_floor_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64
   // CHECK-LABEL: func @math_floor
-  func.func @math_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_floor(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.floor %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_floor_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.floor %arg_f32 : f32
     // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.floor %arg_f64 : f64
     // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_log_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
   // CHECK-LABEL: func @math_log
-  func.func @math_log(%arg_f64 : f64) -> (f64) {
+  func.func @math_log(%arg_f16 : f16, %arg_f64 : f64) -> (f16, f64) {
+    %result16 = math.log %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log_f16(%{{.*}}) : (f16) -> f16
     %result64 = math.log %arg_f64 : f64
     // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
-    func.return %result64 : f64
+    func.return %result16, %result64 : f16, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_log10_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
   // CHECK-LABEL: func @math_log10
-  func.func @math_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_log10(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.log10 %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log10_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.log10 %arg_f32 : f32
     // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.log10 %arg_f64 : f64
     // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_log1p_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
   // CHECK-LABEL: func @math_log1p
-  func.func @math_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_log1p(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.log1p %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_log1p_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.log1p %arg_f32 : f32
     // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.log1p %arg_f64 : f64
     // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_pow_f16(f16, f16) -> f16
   // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32
   // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64
   // CHECK-LABEL: func @math_powf
-  func.func @math_powf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_powf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.powf %arg_f16, %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_pow_f16(%{{.*}}, %{{.*}}) : (f16, f16) -> f16
     %result32 = math.powf %arg_f32, %arg_f32 : f32
     // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
     %result64 = math.powf %arg_f64, %arg_f64 : f64
     // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_rsqrt_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64
   // CHECK-LABEL: func @math_rsqrt
-  func.func @math_rsqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_rsqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.rsqrt %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_rsqrt_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.rsqrt %arg_f32 : f32
     // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.rsqrt %arg_f64 : f64
     // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
   // CHECK-LABEL: func @math_sin
-  func.func @math_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_sin(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.sin %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.sin %arg_f32 : f32
     // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.sin %arg_f64 : f64
     // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_tanh_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
   // CHECK-LABEL: func @math_tanh
-  func.func @math_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.tanh %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_tanh_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.tanh %arg_f32 : f32
     // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.tanh %arg_f64 : f64
     // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_tan_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
   // CHECK-LABEL: func @math_tan
-  func.func @math_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_tan(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.tan %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_tan_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.tan %arg_f32 : f32
     // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.tan %arg_f64 : f64
     // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
+  // CHECK: llvm.func @__ocml_erf_f16(f16) -> f16
   // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32
   // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64
   // CHECK-LABEL: func @math_erf
-  func.func @math_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func.func @math_erf(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = math.erf %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_erf_f16(%{{.*}}) : (f16) -> f16
     %result32 = math.erf %arg_f32 : f32
     // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32
     %result64 = math.erf %arg_f64 : f64
     // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64
-    func.return %result32, %result64 : f32, f64
+    func.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 
 // -----
 
 module @test_module {
-  // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
-  // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
-  // CHECK-LABEL: func @arith_remf
-  func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
-    %result32 = arith.remf %arg_f32, %arg_f32 : f32
-    // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
-    %result64 = arith.remf %arg_f64, %arg_f64 : f64
-    // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
-    func.return %result32, %result64 : f32, f64
+  // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16
+  // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+  // CHECK-LABEL: func @math_casting
+  func.func @math_casting(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64, %arg_bf16 : bf16) -> (f16, f32, f64, bf16) {
+    %resultf16 = math.sin %arg_f16 : f16
+    // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16
+    %resultf32 = math.sin %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+    %resultf64 = math.sin %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
+    %resultbf16 = math.sin %arg_bf16 : bf16
+    // CHECK: llvm.fpext %{{.*}} : bf16 to f32
+    // CHECK-NEXT: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+    // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16
+    func.return %resultf16, %resultf32, %resultf64, %resultbf16 : f16, f32, f64, bf16
   }
-}
-
+}
\ No newline at end of file

>From ccc36b8d0f560b19c4c1dd9b34a0262d8d99a900 Mon Sep 17 00:00:00 2001
From: Daniel Hernandez-Juarez <dhernandez0 at gmail.com>
Date: Mon, 23 Sep 2024 15:02:09 +0000
Subject: [PATCH 2/3] Address PR comments: new line

---
 mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 7a56f44b137463..ddd96bf797e6e7 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -480,4 +480,4 @@ module @test_module {
     // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to bf16
     func.return %resultf16, %resultf32, %resultf64, %resultbf16 : f16, f32, f64, bf16
   }
-}
\ No newline at end of file
+}

>From 24b4bc8f587215d0486724b2f1e6ed2807e3b0dd Mon Sep 17 00:00:00 2001
From: Daniel Hernandez-Juarez <dhernandez0 at gmail.com>
Date: Mon, 23 Sep 2024 15:36:38 +0000
Subject: [PATCH 3/3] Add a scheduling barrier guard around inlineAsm
 lds.barrier

This commit adds a scheduling regions around the inlineAsm
to guard against possible complications arising from them
interfering with the backend scheduler / register allocation.
---
 mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp  | 12 ++++++++++--
 .../Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir    |  4 ++++
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index f80d2793eaef59..fdeedc0a91e307 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -290,15 +290,23 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
     if (requiresInlineAsm) {
       auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                       LLVM::AsmDialect::AD_ATT);
+      Location loc = op->getLoc();
+      // Ensure the inlineAsm is guarded with a scheduling region
+      // So it will not interfere with backend compilation more than
+      // it needs.
+      rewriter.create<amdgpu::SchedBarrierOp>(
+          loc, amdgpu::sched_barrier_opt_enum::none);
       const char *asmStr =
           ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
       const char *constraints = "";
-      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
-          op,
+      rewriter.create<LLVM::InlineAsmOp>(
+          loc,
           /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
           /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
           /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
           /*operand_attrs=*/ArrayAttr());
+      rewriter.replaceOpWithNewOp<amdgpu::SchedBarrierOp>(
+          op, amdgpu::sched_barrier_opt_enum::none);
       return success();
     }
     constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 9f4db151043455..e10c1d6c7df04c 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -238,14 +238,18 @@ func.func @amdgpu_raw_buffer_atomic_cmpswap_v2f16(%src : vector<2xf16>, %cmp : v
 
 // CHECK-LABEL: func @lds_barrier
 func.func @lds_barrier() {
+  // GFX908: rocdl.sched.barrier 0
   // GFX908: llvm.inline_asm has_side_effects asm_dialect = att
   // GFX908-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
+  // GFX908: rocdl.sched.barrier 0
   // GFX90A: rocdl.waitcnt -7937
   // GFX90A-NEXT: rocdl.s.barrier
   // GFX10:  rocdl.waitcnt -16129
   // GFX10-NEXT: rocdl.s.barrier
+  // GFX11: rocdl.sched.barrier 0
   // GFX11:  llvm.inline_asm has_side_effects asm_dialect = att
   // GFX11-SAME: ";;;WARNING: BREAKS DEBUG WATCHES\0As_waitcnt lgkmcnt(0)\0As_barrier"
+  // GFX11: rocdl.sched.barrier 0
   amdgpu.lds_barrier
   func.return
 }



More information about the Mlir-commits mailing list