[Mlir-commits] [mlir] [mlir][AMDGPU] Add support for AMD f16 math library calls (PR #108809)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Sep 19 11:35:29 PDT 2024


================
@@ -63,59 +63,61 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
   // Handled by mathToLLVM: math::SqrtOp
   // Handled by mathToLLVM: math::TruncOp
   populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
-                                   "__ocml_acos_f64");
+                                   "__ocml_acos_f64", "__ocml_acos_f16");
   populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
-                                    "__ocml_acosh_f64");
+                                    "__ocml_acosh_f64", "__ocml_acosh_f16");
   populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
-                                   "__ocml_asin_f64");
+                                   "__ocml_asin_f64", "__ocml_asin_f16");
   populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
-                                    "__ocml_asinh_f64");
+                                    "__ocml_asinh_f64", "__ocml_asinh_f16");
   populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64");
+                                   "__ocml_atan_f64", "__ocml_atan_f16");
   populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
-                                    "__ocml_atanh_f64");
+                                    "__ocml_atanh_f64", "__ocml_atanh_f16");
   populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64");
+                                    "__ocml_atan2_f64", "__ocml_atan2_f16");
   populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64");
+                                   "__ocml_cbrt_f64", "__ocml_cbrt_f16");
   populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64");
+                                   "__ocml_ceil_f64", "__ocml_ceil_f16");
   populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64");
+                                  "__ocml_cos_f64", "__ocml_cos_f16");
   populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
-                                   "__ocml_cosh_f64");
+                                   "__ocml_cosh_f64", "__ocml_cosh_f16");
   populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
-                                   "__ocml_sinh_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");
+                                   "__ocml_sinh_f64", "__ocml_sinh_f16");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64",
+                                  "__ocml_exp_f16");
   populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64");
+                                   "__ocml_exp2_f64", "__ocml_exp2_f16");
   populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64");
+                                    "__ocml_expm1_f64", "__ocml_expm1_f16");
   populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");
+                                    "__ocml_floor_f64", "__ocml_floor_f16");
+  populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64",
+                                  "__ocml_log_f16");
----------------
krzysz00 wrote:

This log implementation is definitely non-trivial
```
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
define linkonce_odr protected half @__ocml_log_f16(half noundef %0) local_unnamed_addr #0 {
  %2 = fpext half %0 to float
  %3 = tail call float @llvm.amdgcn.log.f32(float %2)
  %4 = fmul float %3, 0x3FE62E4300000000
  %5 = fptrunc float %4 to half
  ret half %5
}

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(none)
define linkonce_odr protected half @__ocml_log_f16(half noundef %0) local_unnamed_addr #0 {
  %2 = fpext half %0 to float
  %3 = tail call float @llvm.amdgcn.log.f32(float %2)
  %4 = fmul float %3, 0x3FE62E4300000000
  %5 = fptrunc float %4 to half
  ret half %5
}
```
so we might want to leave the `__ocml_log_f16` call in (the f32 one just calls an intrinsic)

https://github.com/llvm/llvm-project/pull/108809


More information about the Mlir-commits mailing list