[clang] [llvm] [WebAssembly] Implement all f16x8 binary instructions. (PR #93360)

Brendan Dahl via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 15:35:37 PDT 2024


https://github.com/brendandahl updated https://github.com/llvm/llvm-project/pull/93360

>From c33801afebb6720bc4b51fb4064b59529c40d298 Mon Sep 17 00:00:00 2001
From: Brendan Dahl <brendan.dahl at gmail.com>
Date: Thu, 23 May 2024 23:38:51 +0000
Subject: [PATCH 1/2] [WebAssembly] Implement all f16x8 binary instructions.

This reuses most of the code that was created for f32x4 and f64x2 binary
instructions and tries to follow how they were implemented.

add/sub/mul/div - use the regular LLVM IR instructions (fadd/fsub/fmul/fdiv)
min/max - use the llvm.minimum/llvm.maximum intrinsics, and also have builtins
pmin/pmax - use the llvm.wasm.pmin/llvm.wasm.pmax intrinsics, and also have builtins

Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md
---
 .../clang/Basic/BuiltinsWebAssembly.def       |  4 ++
 clang/lib/CodeGen/CGBuiltin.cpp               |  4 ++
 clang/test/CodeGen/builtins-wasm.c            | 24 +++++++
 .../WebAssembly/WebAssemblyISelLowering.cpp   |  5 ++
 .../WebAssembly/WebAssemblyInstrSIMD.td       | 37 +++++++---
 .../CodeGen/WebAssembly/half-precision.ll     | 68 +++++++++++++++++++
 llvm/test/MC/WebAssembly/simd-encodings.s     | 24 +++++++
 7 files changed, 157 insertions(+), 9 deletions(-)

diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def b/clang/include/clang/Basic/BuiltinsWebAssembly.def
index fd8c1b480d6da..4e48ff48b60f5 100644
--- a/clang/include/clang/Basic/BuiltinsWebAssembly.def
+++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def
@@ -135,6 +135,10 @@ TARGET_BUILTIN(__builtin_wasm_min_f64x2, "V2dV2dV2d", "nc", "simd128")
 TARGET_BUILTIN(__builtin_wasm_max_f64x2, "V2dV2dV2d", "nc", "simd128")
 TARGET_BUILTIN(__builtin_wasm_pmin_f64x2, "V2dV2dV2d", "nc", "simd128")
 TARGET_BUILTIN(__builtin_wasm_pmax_f64x2, "V2dV2dV2d", "nc", "simd128")
+TARGET_BUILTIN(__builtin_wasm_min_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_max_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_pmin_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_pmax_f16x8, "V8hV8hV8h", "nc", "half-precision")
 
 TARGET_BUILTIN(__builtin_wasm_ceil_f32x4, "V4fV4f", "nc", "simd128")
 TARGET_BUILTIN(__builtin_wasm_floor_f32x4, "V4fV4f", "nc", "simd128")
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0549afa12e430..f8be7182b5267 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -20779,6 +20779,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
   }
   case WebAssembly::BI__builtin_wasm_min_f32:
   case WebAssembly::BI__builtin_wasm_min_f64:
+  case WebAssembly::BI__builtin_wasm_min_f16x8:
   case WebAssembly::BI__builtin_wasm_min_f32x4:
   case WebAssembly::BI__builtin_wasm_min_f64x2: {
     Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20789,6 +20790,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
   }
   case WebAssembly::BI__builtin_wasm_max_f32:
   case WebAssembly::BI__builtin_wasm_max_f64:
+  case WebAssembly::BI__builtin_wasm_max_f16x8:
   case WebAssembly::BI__builtin_wasm_max_f32x4:
   case WebAssembly::BI__builtin_wasm_max_f64x2: {
     Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20797,6 +20799,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
         CGM.getIntrinsic(Intrinsic::maximum, ConvertType(E->getType()));
     return Builder.CreateCall(Callee, {LHS, RHS});
   }
+  case WebAssembly::BI__builtin_wasm_pmin_f16x8:
   case WebAssembly::BI__builtin_wasm_pmin_f32x4:
   case WebAssembly::BI__builtin_wasm_pmin_f64x2: {
     Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20805,6 +20808,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
         CGM.getIntrinsic(Intrinsic::wasm_pmin, ConvertType(E->getType()));
     return Builder.CreateCall(Callee, {LHS, RHS});
   }
+  case WebAssembly::BI__builtin_wasm_pmax_f16x8:
   case WebAssembly::BI__builtin_wasm_pmax_f32x4:
   case WebAssembly::BI__builtin_wasm_pmax_f64x2: {
     Value *LHS = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGen/builtins-wasm.c b/clang/test/CodeGen/builtins-wasm.c
index 93a6ab06081c9..d6ee4f68700dc 100644
--- a/clang/test/CodeGen/builtins-wasm.c
+++ b/clang/test/CodeGen/builtins-wasm.c
@@ -825,6 +825,30 @@ float extract_lane_f16x8(f16x8 a, int i) {
   // WEBASSEMBLY-NEXT: ret float %0
   return __builtin_wasm_extract_lane_f16x8(a, i);
 }
+
+f16x8 min_f16x8(f16x8 a, f16x8 b) {
+  // WEBASSEMBLY:  %0 = tail call <8 x half> @llvm.minimum.v8f16(<8 x half> %a, <8 x half> %b)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_min_f16x8(a, b);
+}
+
+f16x8 max_f16x8(f16x8 a, f16x8 b) {
+  // WEBASSEMBLY:  %0 = tail call <8 x half> @llvm.maximum.v8f16(<8 x half> %a, <8 x half> %b)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_max_f16x8(a, b);
+}
+
+f16x8 pmin_f16x8(f16x8 a, f16x8 b) {
+  // WEBASSEMBLY:  %0 = tail call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_pmin_f16x8(a, b);
+}
+
+f16x8 pmax_f16x8(f16x8 a, f16x8 b) {
+  // WEBASSEMBLY:  %0 = tail call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_pmax_f16x8(a, b);
+}
 __externref_t externref_null() {
   return __builtin_wasm_ref_null_extern();
   // WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 518b6932a0c87..7cbae1bef8ef4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -142,6 +142,11 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     setTruncStoreAction(T, MVT::f16, Expand);
   }
 
+  if (Subtarget->hasHalfPrecision()) {
+    setOperationAction(ISD::FMINIMUM, MVT::v8f16, Legal);
+    setOperationAction(ISD::FMAXIMUM, MVT::v8f16, Legal);
+  }
+
   // Expand unavailable integer operations.
   for (auto Op :
        {ISD::BSWAP, ISD::SMUL_LOHI, ISD::UMUL_LOHI, ISD::MULHS, ISD::MULHU,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 558e3d859dcd8..83260fbaa700b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -16,33 +16,34 @@
 multiclass ABSTRACT_SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                            list<dag> pattern_r, string asmstr_r,
                            string asmstr_s, bits<32> simdop,
-                           Predicate simd_level> {
+                           list<Predicate> reqs> {
   defm "" : I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r, asmstr_s,
               !if(!ge(simdop, 0x100),
                   !or(0xfd0000, !and(0xffff, simdop)),
                   !or(0xfd00, !and(0xff, simdop)))>,
-            Requires<[simd_level]>;
+            Requires<reqs>;
 }
 
 multiclass SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                   list<dag> pattern_r, string asmstr_r = "",
-                  string asmstr_s = "", bits<32> simdop = -1> {
+                  string asmstr_s = "", bits<32> simdop = -1,
+                  list<Predicate> reqs = []> {
   defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
-                            asmstr_s, simdop, HasSIMD128>;
+                            asmstr_s, simdop, !listconcat([HasSIMD128], reqs)>;
 }
 
 multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                      list<dag> pattern_r, string asmstr_r = "",
                      string asmstr_s = "", bits<32> simdop = -1> {
   defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
-                            asmstr_s, simdop, HasRelaxedSIMD>;
+                            asmstr_s, simdop, [HasRelaxedSIMD]>;
 }
 
 multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                             list<dag> pattern_r, string asmstr_r = "",
                             string asmstr_s = "", bits<32> simdop = -1> {
   defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
-                            asmstr_s, simdop, HasHalfPrecision>;
+                            asmstr_s, simdop, [HasHalfPrecision]>;
 }
 
 
@@ -152,6 +153,18 @@ def F64x2 : Vec {
   let prefix = "f64x2";
 }
 
+def F16x8 : Vec {
+ let vt = v8f16;
+ let int_vt = v8i16;
+ let lane_vt = f32;
+ let lane_rc = F32;
+ let lane_bits = 16;
+ let lane_idx = LaneIdx8;
+ let lane_load = int_wasm_loadf16_f32;
+ let splat = PatFrag<(ops node:$x), (v8f16 (splat_vector (f16 $x)))>;
+ let prefix = "f16x8";
+}
+
 defvar AllVecs = [I8x16, I16x8, I32x4, I64x2, F32x4, F64x2];
 defvar IntVecs = [I8x16, I16x8, I32x4, I64x2];
 
@@ -781,13 +794,14 @@ def : Pat<(v2i64 (nodes[0] (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
 // Bitwise operations
 //===----------------------------------------------------------------------===//
 
-multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name, bits<32> simdop> {
+multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name,
+                      bits<32> simdop, list<Predicate> reqs = []> {
   defm _#vec : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs),
                       (outs), (ins),
                       [(set (vec.vt V128:$dst),
                         (node (vec.vt V128:$lhs), (vec.vt V128:$rhs)))],
                       vec.prefix#"."#name#"\t$dst, $lhs, $rhs",
-                      vec.prefix#"."#name, simdop>;
+                      vec.prefix#"."#name, simdop, reqs>;
 }
 
 multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
@@ -1199,6 +1213,7 @@ def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
 multiclass SIMDBinaryFP<SDPatternOperator node, string name, bits<32> baseInst> {
   defm "" : SIMDBinary<F32x4, node, name, baseInst>;
   defm "" : SIMDBinary<F64x2, node, name, !add(baseInst, 12)>;
+  defm "" : SIMDBinary<F16x8, node, name, !add(baseInst, 80), [HasHalfPrecision]>;
 }
 
 // Addition: add
@@ -1242,7 +1257,7 @@ defm PMAX : SIMDBinaryFP<pmax, "pmax", 235>;
 // Also match the pmin/pmax cases where the operands are int vectors (but the
 // comparison is still a floating point comparison). This can happen when using
 // the wasm_simd128.h intrinsics because v128_t is an integer vector.
-foreach vec = [F32x4, F64x2] in {
+foreach vec = [F32x4, F64x2, F16x8] in {
 defvar pmin = !cast<NI>("PMIN_"#vec);
 defvar pmax = !cast<NI>("PMAX_"#vec);
 def : Pat<(vec.int_vt (vselect
@@ -1266,6 +1281,10 @@ def : Pat<(v2f64 (int_wasm_pmin (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
           (PMIN_F64x2 V128:$lhs, V128:$rhs)>;
 def : Pat<(v2f64 (int_wasm_pmax (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
           (PMAX_F64x2 V128:$lhs, V128:$rhs)>;
+def : Pat<(v8f16 (int_wasm_pmin (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
+          (PMIN_F16x8 V128:$lhs, V128:$rhs)>;
+def : Pat<(v8f16 (int_wasm_pmax (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
+          (PMAX_F16x8 V128:$lhs, V128:$rhs)>;
 
 //===----------------------------------------------------------------------===//
 // Conversions
diff --git a/llvm/test/CodeGen/WebAssembly/half-precision.ll b/llvm/test/CodeGen/WebAssembly/half-precision.ll
index d9d3f6be800fd..73ccea8d652db 100644
--- a/llvm/test/CodeGen/WebAssembly/half-precision.ll
+++ b/llvm/test/CodeGen/WebAssembly/half-precision.ll
@@ -35,3 +35,71 @@ define float @extract_lane_v8f16(<8 x half> %v) {
   %r = call float @llvm.wasm.extract.lane.f16x8(<8 x half> %v, i32 1)
   ret float %r
 }
+
+; CHECK-LABEL: add_v8f16:
+; CHECK:       f16x8.add $push0=, $0, $1
+; CHECK-NEXT:  return $pop0
+define <8 x half> @add_v8f16(<8 x half> %a, <8 x half> %b) {
+  %r = fadd <8 x half> %a, %b
+  ret <8 x half> %r
+}
+
+; CHECK-LABEL: sub_v8f16:
+; CHECK:       f16x8.sub $push0=, $0, $1
+; CHECK-NEXT:  return $pop0
+define <8 x half> @sub_v8f16(<8 x half> %a, <8 x half> %b) {
+  %r = fsub <8 x half> %a, %b
+  ret <8 x half> %r
+}
+
+; CHECK-LABEL: mul_v8f16:
+; CHECK:       f16x8.mul $push0=, $0, $1
+; CHECK-NEXT:  return $pop0
+define <8 x half> @mul_v8f16(<8 x half> %a, <8 x half> %b) {
+  %r = fmul <8 x half> %a, %b
+  ret <8 x half> %r
+}
+
+; CHECK-LABEL: div_v8f16:
+; CHECK:       f16x8.div $push0=, $0, $1
+; CHECK-NEXT:  return $pop0
+define <8 x half> @div_v8f16(<8 x half> %a, <8 x half> %b) {
+  %r = fdiv <8 x half> %a, %b
+  ret <8 x half> %r
+}
+
+; CHECK-LABEL: min_intrinsic_v8f16:
+; CHECK:       f16x8.min $push0=, $0, $1
+; CHECK-NEXT:  return $pop0
+declare <8 x half> @llvm.minimum.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @min_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
+  %a = call <8 x half> @llvm.minimum.v8f16(<8 x half> %x, <8 x half> %y)
+  ret <8 x half> %a
+}
+
+; CHECK-LABEL: max_intrinsic_v8f16:
+; CHECK:       f16x8.max $push0=, $0, $1
+; CHECK-NEXT:  return $pop0
+declare <8 x half> @llvm.maximum.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @max_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
+  %a = call <8 x half> @llvm.maximum.v8f16(<8 x half> %x, <8 x half> %y)
+  ret <8 x half> %a
+}
+
+; CHECK-LABEL: pmin_intrinsic_v8f16:
+; CHECK:       f16x8.pmin $push0=, $0, $1
+; CHECK-NEXT:  return $pop0
+declare <8 x half> @llvm.wasm.pmin.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @pmin_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
+  %v = call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
+  ret <8 x half> %v
+}
+
+; CHECK-LABEL: pmax_intrinsic_v8f16:
+; CHECK:       f16x8.pmax $push0=, $0, $1
+; CHECK-NEXT:  return $pop0
+declare <8 x half> @llvm.wasm.pmax.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @pmax_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
+  %v = call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
+  ret <8 x half> %v
+}
diff --git a/llvm/test/MC/WebAssembly/simd-encodings.s b/llvm/test/MC/WebAssembly/simd-encodings.s
index d397188a9882e..113a23da776fa 100644
--- a/llvm/test/MC/WebAssembly/simd-encodings.s
+++ b/llvm/test/MC/WebAssembly/simd-encodings.s
@@ -851,4 +851,28 @@ main:
     # CHECK: f16x8.extract_lane 1 # encoding: [0xfd,0xa1,0x02,0x01]
     f16x8.extract_lane 1
 
+    # CHECK: f16x8.add # encoding: [0xfd,0xb4,0x02]
+    f16x8.add
+
+    # CHECK: f16x8.sub # encoding: [0xfd,0xb5,0x02]
+    f16x8.sub
+
+    # CHECK: f16x8.mul # encoding: [0xfd,0xb6,0x02]
+    f16x8.mul
+
+    # CHECK: f16x8.div # encoding: [0xfd,0xb7,0x02]
+    f16x8.div
+
+    # CHECK: f16x8.min # encoding: [0xfd,0xb8,0x02]
+    f16x8.min
+
+    # CHECK: f16x8.max # encoding: [0xfd,0xb9,0x02]
+    f16x8.max
+
+    # CHECK: f16x8.pmin # encoding: [0xfd,0xba,0x02]
+    f16x8.pmin
+
+    # CHECK: f16x8.pmax # encoding: [0xfd,0xbb,0x02]
+    f16x8.pmax
+
     end_function

>From 11da7bbce1a2df80825e8d14e603a07656209504 Mon Sep 17 00:00:00 2001
From: Brendan Dahl <brendan.dahl at gmail.com>
Date: Tue, 28 May 2024 22:35:21 +0000
Subject: [PATCH 2/2] Review comments.

---
 llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 83260fbaa700b..baf15ccdbe9ed 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -165,6 +165,7 @@ def F16x8 : Vec {
  let prefix = "f16x8";
 }
 
+// TODO: Include F16x8 here when half precision is better supported.
 defvar AllVecs = [I8x16, I16x8, I32x4, I64x2, F32x4, F64x2];
 defvar IntVecs = [I8x16, I16x8, I32x4, I64x2];
 
@@ -804,6 +805,11 @@ multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name,
                       vec.prefix#"."#name, simdop, reqs>;
 }
 
+multiclass HalfPrecisionBinary<Vec vec, SDPatternOperator node, string name,
+                               bits<32> simdop> {
+  defm "" : SIMDBinary<vec, node, name, simdop, [HasHalfPrecision]>;
+}
+
 multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
                        bit commutable = false> {
   let isCommutable = commutable in
@@ -1213,7 +1219,7 @@ def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
 multiclass SIMDBinaryFP<SDPatternOperator node, string name, bits<32> baseInst> {
   defm "" : SIMDBinary<F32x4, node, name, baseInst>;
   defm "" : SIMDBinary<F64x2, node, name, !add(baseInst, 12)>;
-  defm "" : SIMDBinary<F16x8, node, name, !add(baseInst, 80), [HasHalfPrecision]>;
+  defm "" : HalfPrecisionBinary<F16x8, node, name, !add(baseInst, 80)>;
 }
 
 // Addition: add



More information about the llvm-commits mailing list