[clang] [llvm] [WebAssembly] Implement all f16x8 binary instructions. (PR #93360)
Brendan Dahl via llvm-commits
llvm-commits at lists.llvm.org
Tue May 28 15:35:37 PDT 2024
https://github.com/brendandahl updated https://github.com/llvm/llvm-project/pull/93360
>From c33801afebb6720bc4b51fb4064b59529c40d298 Mon Sep 17 00:00:00 2001
From: Brendan Dahl <brendan.dahl at gmail.com>
Date: Thu, 23 May 2024 23:38:51 +0000
Subject: [PATCH 1/2] [WebAssembly] Implement all f16x8 binary instructions.
This reuses most of the code that was created for f32x4 and f64x2 binary
instructions and tries to follow how they were implemented.
add/sub/mul/div - use regular LL instructions
min/max - use the minimum/maximum intrinsics, and also have builtins
pmin/pmax - use the wasm.pmin/pmax intrinsics and also have builtins
Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md
---
.../clang/Basic/BuiltinsWebAssembly.def | 4 ++
clang/lib/CodeGen/CGBuiltin.cpp | 4 ++
clang/test/CodeGen/builtins-wasm.c | 24 +++++++
.../WebAssembly/WebAssemblyISelLowering.cpp | 5 ++
.../WebAssembly/WebAssemblyInstrSIMD.td | 37 +++++++---
.../CodeGen/WebAssembly/half-precision.ll | 68 +++++++++++++++++++
llvm/test/MC/WebAssembly/simd-encodings.s | 24 +++++++
7 files changed, 157 insertions(+), 9 deletions(-)
diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def b/clang/include/clang/Basic/BuiltinsWebAssembly.def
index fd8c1b480d6da..4e48ff48b60f5 100644
--- a/clang/include/clang/Basic/BuiltinsWebAssembly.def
+++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def
@@ -135,6 +135,10 @@ TARGET_BUILTIN(__builtin_wasm_min_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_max_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_pmin_f64x2, "V2dV2dV2d", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_pmax_f64x2, "V2dV2dV2d", "nc", "simd128")
+TARGET_BUILTIN(__builtin_wasm_min_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_max_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_pmin_f16x8, "V8hV8hV8h", "nc", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_pmax_f16x8, "V8hV8hV8h", "nc", "half-precision")
TARGET_BUILTIN(__builtin_wasm_ceil_f32x4, "V4fV4f", "nc", "simd128")
TARGET_BUILTIN(__builtin_wasm_floor_f32x4, "V4fV4f", "nc", "simd128")
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0549afa12e430..f8be7182b5267 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -20779,6 +20779,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
}
case WebAssembly::BI__builtin_wasm_min_f32:
case WebAssembly::BI__builtin_wasm_min_f64:
+ case WebAssembly::BI__builtin_wasm_min_f16x8:
case WebAssembly::BI__builtin_wasm_min_f32x4:
case WebAssembly::BI__builtin_wasm_min_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20789,6 +20790,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
}
case WebAssembly::BI__builtin_wasm_max_f32:
case WebAssembly::BI__builtin_wasm_max_f64:
+ case WebAssembly::BI__builtin_wasm_max_f16x8:
case WebAssembly::BI__builtin_wasm_max_f32x4:
case WebAssembly::BI__builtin_wasm_max_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20797,6 +20799,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
CGM.getIntrinsic(Intrinsic::maximum, ConvertType(E->getType()));
return Builder.CreateCall(Callee, {LHS, RHS});
}
+ case WebAssembly::BI__builtin_wasm_pmin_f16x8:
case WebAssembly::BI__builtin_wasm_pmin_f32x4:
case WebAssembly::BI__builtin_wasm_pmin_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
@@ -20805,6 +20808,7 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
CGM.getIntrinsic(Intrinsic::wasm_pmin, ConvertType(E->getType()));
return Builder.CreateCall(Callee, {LHS, RHS});
}
+ case WebAssembly::BI__builtin_wasm_pmax_f16x8:
case WebAssembly::BI__builtin_wasm_pmax_f32x4:
case WebAssembly::BI__builtin_wasm_pmax_f64x2: {
Value *LHS = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGen/builtins-wasm.c b/clang/test/CodeGen/builtins-wasm.c
index 93a6ab06081c9..d6ee4f68700dc 100644
--- a/clang/test/CodeGen/builtins-wasm.c
+++ b/clang/test/CodeGen/builtins-wasm.c
@@ -825,6 +825,30 @@ float extract_lane_f16x8(f16x8 a, int i) {
// WEBASSEMBLY-NEXT: ret float %0
return __builtin_wasm_extract_lane_f16x8(a, i);
}
+
+f16x8 min_f16x8(f16x8 a, f16x8 b) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.minimum.v8f16(<8 x half> %a, <8 x half> %b)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_min_f16x8(a, b);
+}
+
+f16x8 max_f16x8(f16x8 a, f16x8 b) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.maximum.v8f16(<8 x half> %a, <8 x half> %b)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_max_f16x8(a, b);
+}
+
+f16x8 pmin_f16x8(f16x8 a, f16x8 b) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_pmin_f16x8(a, b);
+}
+
+f16x8 pmax_f16x8(f16x8 a, f16x8 b) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_pmax_f16x8(a, b);
+}
__externref_t externref_null() {
return __builtin_wasm_ref_null_extern();
// WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 518b6932a0c87..7cbae1bef8ef4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -142,6 +142,11 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setTruncStoreAction(T, MVT::f16, Expand);
}
+ if (Subtarget->hasHalfPrecision()) {
+ setOperationAction(ISD::FMINIMUM, MVT::v8f16, Legal);
+ setOperationAction(ISD::FMAXIMUM, MVT::v8f16, Legal);
+ }
+
// Expand unavailable integer operations.
for (auto Op :
{ISD::BSWAP, ISD::SMUL_LOHI, ISD::UMUL_LOHI, ISD::MULHS, ISD::MULHU,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 558e3d859dcd8..83260fbaa700b 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -16,33 +16,34 @@
multiclass ABSTRACT_SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r,
string asmstr_s, bits<32> simdop,
- Predicate simd_level> {
+ list<Predicate> reqs> {
defm "" : I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r, asmstr_s,
!if(!ge(simdop, 0x100),
!or(0xfd0000, !and(0xffff, simdop)),
!or(0xfd00, !and(0xff, simdop)))>,
- Requires<[simd_level]>;
+ Requires<reqs>;
}
multiclass SIMD_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
- string asmstr_s = "", bits<32> simdop = -1> {
+ string asmstr_s = "", bits<32> simdop = -1,
+ list<Predicate> reqs = []> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
- asmstr_s, simdop, HasSIMD128>;
+ asmstr_s, simdop, !listconcat([HasSIMD128], reqs)>;
}
multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
- asmstr_s, simdop, HasRelaxedSIMD>;
+ asmstr_s, simdop, [HasRelaxedSIMD]>;
}
multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
list<dag> pattern_r, string asmstr_r = "",
string asmstr_s = "", bits<32> simdop = -1> {
defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
- asmstr_s, simdop, HasHalfPrecision>;
+ asmstr_s, simdop, [HasHalfPrecision]>;
}
@@ -152,6 +153,18 @@ def F64x2 : Vec {
let prefix = "f64x2";
}
+def F16x8 : Vec {
+ let vt = v8f16;
+ let int_vt = v8i16;
+ let lane_vt = f32;
+ let lane_rc = F32;
+ let lane_bits = 16;
+ let lane_idx = LaneIdx8;
+ let lane_load = int_wasm_loadf16_f32;
+ let splat = PatFrag<(ops node:$x), (v8f16 (splat_vector (f16 $x)))>;
+ let prefix = "f16x8";
+}
+
defvar AllVecs = [I8x16, I16x8, I32x4, I64x2, F32x4, F64x2];
defvar IntVecs = [I8x16, I16x8, I32x4, I64x2];
@@ -781,13 +794,14 @@ def : Pat<(v2i64 (nodes[0] (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
// Bitwise operations
//===----------------------------------------------------------------------===//
-multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name, bits<32> simdop> {
+multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name,
+ bits<32> simdop, list<Predicate> reqs = []> {
defm _#vec : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs),
(outs), (ins),
[(set (vec.vt V128:$dst),
(node (vec.vt V128:$lhs), (vec.vt V128:$rhs)))],
vec.prefix#"."#name#"\t$dst, $lhs, $rhs",
- vec.prefix#"."#name, simdop>;
+ vec.prefix#"."#name, simdop, reqs>;
}
multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
@@ -1199,6 +1213,7 @@ def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
multiclass SIMDBinaryFP<SDPatternOperator node, string name, bits<32> baseInst> {
defm "" : SIMDBinary<F32x4, node, name, baseInst>;
defm "" : SIMDBinary<F64x2, node, name, !add(baseInst, 12)>;
+ defm "" : SIMDBinary<F16x8, node, name, !add(baseInst, 80), [HasHalfPrecision]>;
}
// Addition: add
@@ -1242,7 +1257,7 @@ defm PMAX : SIMDBinaryFP<pmax, "pmax", 235>;
// Also match the pmin/pmax cases where the operands are int vectors (but the
// comparison is still a floating point comparison). This can happen when using
// the wasm_simd128.h intrinsics because v128_t is an integer vector.
-foreach vec = [F32x4, F64x2] in {
+foreach vec = [F32x4, F64x2, F16x8] in {
defvar pmin = !cast<NI>("PMIN_"#vec);
defvar pmax = !cast<NI>("PMAX_"#vec);
def : Pat<(vec.int_vt (vselect
@@ -1266,6 +1281,10 @@ def : Pat<(v2f64 (int_wasm_pmin (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
(PMIN_F64x2 V128:$lhs, V128:$rhs)>;
def : Pat<(v2f64 (int_wasm_pmax (v2f64 V128:$lhs), (v2f64 V128:$rhs))),
(PMAX_F64x2 V128:$lhs, V128:$rhs)>;
+def : Pat<(v8f16 (int_wasm_pmin (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
+ (PMIN_F16x8 V128:$lhs, V128:$rhs)>;
+def : Pat<(v8f16 (int_wasm_pmax (v8f16 V128:$lhs), (v8f16 V128:$rhs))),
+ (PMAX_F16x8 V128:$lhs, V128:$rhs)>;
//===----------------------------------------------------------------------===//
// Conversions
diff --git a/llvm/test/CodeGen/WebAssembly/half-precision.ll b/llvm/test/CodeGen/WebAssembly/half-precision.ll
index d9d3f6be800fd..73ccea8d652db 100644
--- a/llvm/test/CodeGen/WebAssembly/half-precision.ll
+++ b/llvm/test/CodeGen/WebAssembly/half-precision.ll
@@ -35,3 +35,71 @@ define float @extract_lane_v8f16(<8 x half> %v) {
%r = call float @llvm.wasm.extract.lane.f16x8(<8 x half> %v, i32 1)
ret float %r
}
+
+; CHECK-LABEL: add_v8f16:
+; CHECK: f16x8.add $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @add_v8f16(<8 x half> %a, <8 x half> %b) {
+ %r = fadd <8 x half> %a, %b
+ ret <8 x half> %r
+}
+
+; CHECK-LABEL: sub_v8f16:
+; CHECK: f16x8.sub $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @sub_v8f16(<8 x half> %a, <8 x half> %b) {
+ %r = fsub <8 x half> %a, %b
+ ret <8 x half> %r
+}
+
+; CHECK-LABEL: mul_v8f16:
+; CHECK: f16x8.mul $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @mul_v8f16(<8 x half> %a, <8 x half> %b) {
+ %r = fmul <8 x half> %a, %b
+ ret <8 x half> %r
+}
+
+; CHECK-LABEL: div_v8f16:
+; CHECK: f16x8.div $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+define <8 x half> @div_v8f16(<8 x half> %a, <8 x half> %b) {
+ %r = fdiv <8 x half> %a, %b
+ ret <8 x half> %r
+}
+
+; CHECK-LABEL: min_intrinsic_v8f16:
+; CHECK: f16x8.min $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.minimum.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @min_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
+ %a = call <8 x half> @llvm.minimum.v8f16(<8 x half> %x, <8 x half> %y)
+ ret <8 x half> %a
+}
+
+; CHECK-LABEL: max_intrinsic_v8f16:
+; CHECK: f16x8.max $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.maximum.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @max_intrinsic_v8f16(<8 x half> %x, <8 x half> %y) {
+ %a = call <8 x half> @llvm.maximum.v8f16(<8 x half> %x, <8 x half> %y)
+ ret <8 x half> %a
+}
+
+; CHECK-LABEL: pmin_intrinsic_v8f16:
+; CHECK: f16x8.pmin $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.wasm.pmin.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @pmin_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
+ %v = call <8 x half> @llvm.wasm.pmin.v8f16(<8 x half> %a, <8 x half> %b)
+ ret <8 x half> %v
+}
+
+; CHECK-LABEL: pmax_intrinsic_v8f16:
+; CHECK: f16x8.pmax $push0=, $0, $1
+; CHECK-NEXT: return $pop0
+declare <8 x half> @llvm.wasm.pmax.v8f16(<8 x half>, <8 x half>)
+define <8 x half> @pmax_intrinsic_v8f16(<8 x half> %a, <8 x half> %b) {
+ %v = call <8 x half> @llvm.wasm.pmax.v8f16(<8 x half> %a, <8 x half> %b)
+ ret <8 x half> %v
+}
diff --git a/llvm/test/MC/WebAssembly/simd-encodings.s b/llvm/test/MC/WebAssembly/simd-encodings.s
index d397188a9882e..113a23da776fa 100644
--- a/llvm/test/MC/WebAssembly/simd-encodings.s
+++ b/llvm/test/MC/WebAssembly/simd-encodings.s
@@ -851,4 +851,28 @@ main:
# CHECK: f16x8.extract_lane 1 # encoding: [0xfd,0xa1,0x02,0x01]
f16x8.extract_lane 1
+ # CHECK: f16x8.add # encoding: [0xfd,0xb4,0x02]
+ f16x8.add
+
+ # CHECK: f16x8.sub # encoding: [0xfd,0xb5,0x02]
+ f16x8.sub
+
+ # CHECK: f16x8.mul # encoding: [0xfd,0xb6,0x02]
+ f16x8.mul
+
+ # CHECK: f16x8.div # encoding: [0xfd,0xb7,0x02]
+ f16x8.div
+
+ # CHECK: f16x8.min # encoding: [0xfd,0xb8,0x02]
+ f16x8.min
+
+ # CHECK: f16x8.max # encoding: [0xfd,0xb9,0x02]
+ f16x8.max
+
+ # CHECK: f16x8.pmin # encoding: [0xfd,0xba,0x02]
+ f16x8.pmin
+
+ # CHECK: f16x8.pmax # encoding: [0xfd,0xbb,0x02]
+ f16x8.pmax
+
end_function
>From 11da7bbce1a2df80825e8d14e603a07656209504 Mon Sep 17 00:00:00 2001
From: Brendan Dahl <brendan.dahl at gmail.com>
Date: Tue, 28 May 2024 22:35:21 +0000
Subject: [PATCH 2/2] Review comments.
---
llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 83260fbaa700b..baf15ccdbe9ed 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -165,6 +165,7 @@ def F16x8 : Vec {
let prefix = "f16x8";
}
+// TODO: Include F16x8 here when half precision is better supported.
defvar AllVecs = [I8x16, I16x8, I32x4, I64x2, F32x4, F64x2];
defvar IntVecs = [I8x16, I16x8, I32x4, I64x2];
@@ -804,6 +805,11 @@ multiclass SIMDBinary<Vec vec, SDPatternOperator node, string name,
vec.prefix#"."#name, simdop, reqs>;
}
+multiclass HalfPrecisionBinary<Vec vec, SDPatternOperator node, string name,
+ bits<32> simdop> {
+ defm "" : SIMDBinary<vec, node, name, simdop, [HasHalfPrecision]>;
+}
+
multiclass SIMDBitwise<SDPatternOperator node, string name, bits<32> simdop,
bit commutable = false> {
let isCommutable = commutable in
@@ -1213,7 +1219,7 @@ def : Pat<(v2f64 (froundeven (v2f64 V128:$src))), (NEAREST_F64x2 V128:$src)>;
multiclass SIMDBinaryFP<SDPatternOperator node, string name, bits<32> baseInst> {
defm "" : SIMDBinary<F32x4, node, name, baseInst>;
defm "" : SIMDBinary<F64x2, node, name, !add(baseInst, 12)>;
- defm "" : SIMDBinary<F16x8, node, name, !add(baseInst, 80), [HasHalfPrecision]>;
+ defm "" : HalfPrecisionBinary<F16x8, node, name, !add(baseInst, 80)>;
}
// Addition: add
More information about the llvm-commits
mailing list