[clang] [llvm] [WebAssembly] Implement prototype f16x8.splat instruction. (PR #93228)
Brendan Dahl via llvm-commits
llvm-commits at lists.llvm.org
Thu May 23 11:41:54 PDT 2024
https://github.com/brendandahl updated https://github.com/llvm/llvm-project/pull/93228
>From 28cc678038feefffceba8cbe24349e1885b24c75 Mon Sep 17 00:00:00 2001
From: Brendan Dahl <brendan.dahl at gmail.com>
Date: Tue, 21 May 2024 21:15:14 +0000
Subject: [PATCH] [WebAssembly] Implement prototype f16x8.splat instruction.
Adds a Clang builtin and an LLVM intrinsic for the f16x8.splat instruction.
Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md
Note: the current spec lists f16x8.splat as opcode 0x123, but that is incorrect and will be changed to 0x120 soon; this patch already uses 0x120.
---
clang/include/clang/Basic/BuiltinsWebAssembly.def | 1 +
clang/lib/Basic/Targets/WebAssembly.h | 1 +
clang/lib/CodeGen/CGBuiltin.cpp | 5 +++++
clang/test/CodeGen/builtins-wasm.c | 6 ++++++
llvm/include/llvm/IR/IntrinsicsWebAssembly.td | 4 ++++
.../Utils/WebAssemblyTypeUtilities.cpp | 1 +
.../WebAssembly/WebAssemblyISelLowering.cpp | 3 +++
.../Target/WebAssembly/WebAssemblyInstrSIMD.td | 15 +++++++++++++++
.../Target/WebAssembly/WebAssemblyRegisterInfo.td | 5 +++--
llvm/test/CodeGen/WebAssembly/half-precision.ll | 12 ++++++++++--
llvm/test/MC/WebAssembly/simd-encodings.s | 3 +++
11 files changed, 52 insertions(+), 4 deletions(-)
diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def b/clang/include/clang/Basic/BuiltinsWebAssembly.def
index 8645cff1e8679..dbe79aa39190d 100644
--- a/clang/include/clang/Basic/BuiltinsWebAssembly.def
+++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def
@@ -193,6 +193,7 @@ TARGET_BUILTIN(__builtin_wasm_relaxed_dot_bf16x8_add_f32_f32x4, "V4fV8UsV8UsV4f"
// Half-Precision (fp16)
TARGET_BUILTIN(__builtin_wasm_loadf16_f32, "fh*", "nU", "half-precision")
TARGET_BUILTIN(__builtin_wasm_storef16_f32, "vfh*", "n", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_splat_f16x8, "V8hf", "nc", "half-precision")
// Reference Types builtins
// Some builtins are custom type-checked - see 't' as part of the third argument,
diff --git a/clang/lib/Basic/Targets/WebAssembly.h b/clang/lib/Basic/Targets/WebAssembly.h
index 4db97867df607..46416d516b42f 100644
--- a/clang/lib/Basic/Targets/WebAssembly.h
+++ b/clang/lib/Basic/Targets/WebAssembly.h
@@ -90,6 +90,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {
StringRef getABI() const override;
bool setABI(const std::string &Name) override;
+ bool useFP16ConversionIntrinsics() const override { return false; }
protected:
void getTargetDefines(const LangOptions &Opts,
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index ba94bf89e4751..91083c1cfae96 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -21230,6 +21230,11 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_storef16_f32);
return Builder.CreateCall(Callee, {Val, Addr});
}
+ case WebAssembly::BI__builtin_wasm_splat_f16x8: {
+ Value *Val = EmitScalarExpr(E->getArg(0));
+ Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_splat_f16x8);
+ return Builder.CreateCall(Callee, {Val});
+ }
case WebAssembly::BI__builtin_wasm_table_get: {
assert(E->getArg(0)->getType()->isArrayType());
Value *Table = EmitArrayToPointerDecay(E->getArg(0)).emitRawPointer(*this);
diff --git a/clang/test/CodeGen/builtins-wasm.c b/clang/test/CodeGen/builtins-wasm.c
index bcb15969de1c5..76c6305d422a2 100644
--- a/clang/test/CodeGen/builtins-wasm.c
+++ b/clang/test/CodeGen/builtins-wasm.c
@@ -11,6 +11,7 @@ typedef unsigned char u8x16 __attribute((vector_size(16)));
typedef unsigned short u16x8 __attribute((vector_size(16)));
typedef unsigned int u32x4 __attribute((vector_size(16)));
typedef unsigned long long u64x2 __attribute((vector_size(16)));
+typedef __fp16 f16x8 __attribute((vector_size(16)));
typedef float f32x4 __attribute((vector_size(16)));
typedef double f64x2 __attribute((vector_size(16)));
@@ -813,6 +814,11 @@ void store_f16_f32(float val, __fp16 *addr) {
// WEBASSEMBLY-NEXT: ret
}
+f16x8 splat_f16x8(float a) {
+ // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.splat.f16x8(float %a)
+ // WEBASSEMBLY-NEXT: ret <8 x half> %0
+ return __builtin_wasm_splat_f16x8(a);
+}
__externref_t externref_null() {
return __builtin_wasm_ref_null_extern();
// WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
index 572d334ac9552..c950b33182689 100644
--- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
+++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
@@ -337,6 +337,10 @@ def int_wasm_storef16_f32:
[llvm_float_ty, llvm_ptr_ty],
[IntrWriteMem, IntrArgMemOnly],
"", [SDNPMemOperand]>;
+def int_wasm_splat_f16x8:
+ DefaultAttrsIntrinsic<[llvm_v8f16_ty],
+ [llvm_float_ty],
+ [IntrNoMem, IntrSpeculatable]>;
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/WebAssembly/Utils/WebAssemblyTypeUtilities.cpp b/llvm/lib/Target/WebAssembly/Utils/WebAssemblyTypeUtilities.cpp
index fac2e0d935f5a..867953b4e8d71 100644
--- a/llvm/lib/Target/WebAssembly/Utils/WebAssemblyTypeUtilities.cpp
+++ b/llvm/lib/Target/WebAssembly/Utils/WebAssemblyTypeUtilities.cpp
@@ -50,6 +50,7 @@ wasm::ValType WebAssembly::toValType(MVT Type) {
case MVT::v8i16:
case MVT::v4i32:
case MVT::v2i64:
+ case MVT::v8f16:
case MVT::v4f32:
case MVT::v2f64:
return wasm::ValType::V128;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 527bb4c9fbea6..b0b2a9e55ae44 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -70,6 +70,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
addRegisterClass(MVT::v2i64, &WebAssembly::V128RegClass);
addRegisterClass(MVT::v2f64, &WebAssembly::V128RegClass);
}
+ if (Subtarget->hasHalfPrecision()) {
+ addRegisterClass(MVT::v8f16, &WebAssembly::V128RegClass);
+ }
if (Subtarget->hasReferenceTypes()) {
addRegisterClass(MVT::externref, &WebAssembly::EXTERNREFRegClass);
addRegisterClass(MVT::funcref, &WebAssembly::FUNCREFRegClass);
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index af95dfa25a189..bb898e7bebd3a 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -38,6 +38,13 @@ multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
asmstr_s, simdop, HasRelaxedSIMD>;
}
+multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
+ list<dag> pattern_r, string asmstr_r = "",
+ string asmstr_s = "", bits<32> simdop = -1> {
+ defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
+ asmstr_s, simdop, HasHalfPrecision>;
+}
+
defm "" : ARGUMENT<V128, v16i8>;
defm "" : ARGUMENT<V128, v8i16>;
@@ -591,6 +598,14 @@ defm "" : Splat<I64x2, 18>;
defm "" : Splat<F32x4, 19>;
defm "" : Splat<F64x2, 20>;
+// Half values are not fully supported so an intrinsic is used instead of a
+// regular Splat pattern as above.
+defm SPLAT_F16x8 :
+ HALF_PRECISION_I<(outs V128:$dst), (ins F32:$x),
+ (outs), (ins),
+ [(set (v8f16 V128:$dst), (int_wasm_splat_f16x8 F32:$x))],
+ "f16x8.splat\t$dst, $x", "f16x8.splat", 0x120>;
+
// scalar_to_vector leaves high lanes undefined, so can be a splat
foreach vec = AllVecs in
def : Pat<(vec.vt (scalar_to_vector (vec.lane_vt vec.lane_rc:$x))),
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td
index ba2936b492a9a..4e2faa608be07 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td
@@ -63,7 +63,8 @@ def I32 : WebAssemblyRegClass<[i32], 32, (add FP32, SP32, I32_0)>;
def I64 : WebAssemblyRegClass<[i64], 64, (add FP64, SP64, I64_0)>;
def F32 : WebAssemblyRegClass<[f32], 32, (add F32_0)>;
def F64 : WebAssemblyRegClass<[f64], 64, (add F64_0)>;
-def V128 : WebAssemblyRegClass<[v4f32, v2f64, v2i64, v4i32, v16i8, v8i16], 128,
- (add V128_0)>;
+def V128 : WebAssemblyRegClass<[v8f16, v4f32, v2f64, v2i64, v4i32, v16i8,
+ v8i16],
+ 128, (add V128_0)>;
def FUNCREF : WebAssemblyRegClass<[funcref], 0, (add FUNCREF_0)>;
def EXTERNREF : WebAssemblyRegClass<[externref], 0, (add EXTERNREF_0)>;
diff --git a/llvm/test/CodeGen/WebAssembly/half-precision.ll b/llvm/test/CodeGen/WebAssembly/half-precision.ll
index 89e9c42637c14..eee5bf8b8c48a 100644
--- a/llvm/test/CodeGen/WebAssembly/half-precision.ll
+++ b/llvm/test/CodeGen/WebAssembly/half-precision.ll
@@ -1,5 +1,5 @@
-; RUN: llc < %s --mtriple=wasm32-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision | FileCheck %s
-; RUN: llc < %s --mtriple=wasm64-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision | FileCheck %s
+; RUN: llc < %s --mtriple=wasm32-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision,+simd128 | FileCheck %s
+; RUN: llc < %s --mtriple=wasm64-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision,+simd128 | FileCheck %s
declare float @llvm.wasm.loadf32.f16(ptr)
declare void @llvm.wasm.storef16.f32(float, ptr)
@@ -19,3 +19,11 @@ define void @stf16_32(float %v, ptr %p) {
tail call void @llvm.wasm.storef16.f32(float %v, ptr %p)
ret void
}
+
+; CHECK-LABEL: splat_v8f16:
+; CHECK: f16x8.splat $push0=, $0
+; CHECK-NEXT: return $pop0
+define <8 x half> @splat_v8f16(float %x) {
+ %v = call <8 x half> @llvm.wasm.splat.f16x8(float %x)
+ ret <8 x half> %v
+}
diff --git a/llvm/test/MC/WebAssembly/simd-encodings.s b/llvm/test/MC/WebAssembly/simd-encodings.s
index 57fa71e74b8d7..c23a9d1958099 100644
--- a/llvm/test/MC/WebAssembly/simd-encodings.s
+++ b/llvm/test/MC/WebAssembly/simd-encodings.s
@@ -845,4 +845,7 @@ main:
# CHECK: f32.store_f16 32 # encoding: [0xfc,0x31,0x01,0x20]
f32.store_f16 32
+ # CHECK: f16x8.splat # encoding: [0xfd,0xa0,0x02]
+ f16x8.splat
+
end_function
More information about the llvm-commits
mailing list