[clang] [llvm] [WebAssembly] Implement prototype f16x8.splat instruction. (PR #93228)

Brendan Dahl via llvm-commits llvm-commits at lists.llvm.org
Thu May 23 11:41:54 PDT 2024


https://github.com/brendandahl updated https://github.com/llvm/llvm-project/pull/93228

>From 28cc678038feefffceba8cbe24349e1885b24c75 Mon Sep 17 00:00:00 2001
From: Brendan Dahl <brendan.dahl at gmail.com>
Date: Tue, 21 May 2024 21:15:14 +0000
Subject: [PATCH] [WebAssembly] Implement prototype f16x8.splat instruction.

Adds a Clang builtin (__builtin_wasm_splat_f16x8) and an LLVM intrinsic (llvm.wasm.splat.f16x8) for the f16x8.splat instruction.

Specified at:
https://github.com/WebAssembly/half-precision/blob/29a9b9462c9285d4ccc1a5dc39214ddfd1892658/proposals/half-precision/Overview.md

Note: the current spec lists f16x8.splat as opcode 0x123, but this is incorrect and will soon be changed to 0x120; this patch already uses 0x120.
---
 clang/include/clang/Basic/BuiltinsWebAssembly.def |  1 +
 clang/lib/Basic/Targets/WebAssembly.h             |  1 +
 clang/lib/CodeGen/CGBuiltin.cpp                   |  5 +++++
 clang/test/CodeGen/builtins-wasm.c                |  6 ++++++
 llvm/include/llvm/IR/IntrinsicsWebAssembly.td     |  4 ++++
 .../Utils/WebAssemblyTypeUtilities.cpp            |  1 +
 .../WebAssembly/WebAssemblyISelLowering.cpp       |  3 +++
 .../Target/WebAssembly/WebAssemblyInstrSIMD.td    | 15 +++++++++++++++
 .../Target/WebAssembly/WebAssemblyRegisterInfo.td |  5 +++--
 llvm/test/CodeGen/WebAssembly/half-precision.ll   | 12 ++++++++++--
 llvm/test/MC/WebAssembly/simd-encodings.s         |  3 +++
 11 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/clang/include/clang/Basic/BuiltinsWebAssembly.def b/clang/include/clang/Basic/BuiltinsWebAssembly.def
index 8645cff1e8679..dbe79aa39190d 100644
--- a/clang/include/clang/Basic/BuiltinsWebAssembly.def
+++ b/clang/include/clang/Basic/BuiltinsWebAssembly.def
@@ -193,6 +193,7 @@ TARGET_BUILTIN(__builtin_wasm_relaxed_dot_bf16x8_add_f32_f32x4, "V4fV8UsV8UsV4f"
 // Half-Precision (fp16)
 TARGET_BUILTIN(__builtin_wasm_loadf16_f32, "fh*", "nU", "half-precision")
 TARGET_BUILTIN(__builtin_wasm_storef16_f32, "vfh*", "n", "half-precision")
+TARGET_BUILTIN(__builtin_wasm_splat_f16x8, "V8hf", "nc", "half-precision")
 
 // Reference Types builtins
 // Some builtins are custom type-checked - see 't' as part of the third argument,
diff --git a/clang/lib/Basic/Targets/WebAssembly.h b/clang/lib/Basic/Targets/WebAssembly.h
index 4db97867df607..46416d516b42f 100644
--- a/clang/lib/Basic/Targets/WebAssembly.h
+++ b/clang/lib/Basic/Targets/WebAssembly.h
@@ -90,6 +90,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {
 
   StringRef getABI() const override;
   bool setABI(const std::string &Name) override;
+  bool useFP16ConversionIntrinsics() const override { return false; }
 
 protected:
   void getTargetDefines(const LangOptions &Opts,
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index ba94bf89e4751..91083c1cfae96 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -21230,6 +21230,11 @@ Value *CodeGenFunction::EmitWebAssemblyBuiltinExpr(unsigned BuiltinID,
     Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_storef16_f32);
     return Builder.CreateCall(Callee, {Val, Addr});
   }
+  case WebAssembly::BI__builtin_wasm_splat_f16x8: {
+    Value *Val = EmitScalarExpr(E->getArg(0));
+    Function *Callee = CGM.getIntrinsic(Intrinsic::wasm_splat_f16x8);
+    return Builder.CreateCall(Callee, {Val});
+  }
   case WebAssembly::BI__builtin_wasm_table_get: {
     assert(E->getArg(0)->getType()->isArrayType());
     Value *Table = EmitArrayToPointerDecay(E->getArg(0)).emitRawPointer(*this);
diff --git a/clang/test/CodeGen/builtins-wasm.c b/clang/test/CodeGen/builtins-wasm.c
index bcb15969de1c5..76c6305d422a2 100644
--- a/clang/test/CodeGen/builtins-wasm.c
+++ b/clang/test/CodeGen/builtins-wasm.c
@@ -11,6 +11,7 @@ typedef unsigned char u8x16 __attribute((vector_size(16)));
 typedef unsigned short u16x8 __attribute((vector_size(16)));
 typedef unsigned int u32x4 __attribute((vector_size(16)));
 typedef unsigned long long u64x2 __attribute((vector_size(16)));
+typedef __fp16 f16x8 __attribute((vector_size(16)));
 typedef float f32x4 __attribute((vector_size(16)));
 typedef double f64x2 __attribute((vector_size(16)));
 
@@ -813,6 +814,11 @@ void store_f16_f32(float val, __fp16 *addr) {
   // WEBASSEMBLY-NEXT: ret
 }
 
+f16x8 splat_f16x8(float a) {
+  // WEBASSEMBLY: %0 = tail call <8 x half> @llvm.wasm.splat.f16x8(float %a)
+  // WEBASSEMBLY-NEXT: ret <8 x half> %0
+  return __builtin_wasm_splat_f16x8(a);
+}
 __externref_t externref_null() {
   return __builtin_wasm_ref_null_extern();
   // WEBASSEMBLY: tail call ptr addrspace(10) @llvm.wasm.ref.null.extern()
diff --git a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
index 572d334ac9552..c950b33182689 100644
--- a/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
+++ b/llvm/include/llvm/IR/IntrinsicsWebAssembly.td
@@ -337,6 +337,10 @@ def int_wasm_storef16_f32:
             [llvm_float_ty, llvm_ptr_ty],
             [IntrWriteMem, IntrArgMemOnly],
              "", [SDNPMemOperand]>;
+def int_wasm_splat_f16x8:
+  DefaultAttrsIntrinsic<[llvm_v8f16_ty],
+                        [llvm_float_ty],
+                        [IntrNoMem, IntrSpeculatable]>;
 
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/WebAssembly/Utils/WebAssemblyTypeUtilities.cpp b/llvm/lib/Target/WebAssembly/Utils/WebAssemblyTypeUtilities.cpp
index fac2e0d935f5a..867953b4e8d71 100644
--- a/llvm/lib/Target/WebAssembly/Utils/WebAssemblyTypeUtilities.cpp
+++ b/llvm/lib/Target/WebAssembly/Utils/WebAssemblyTypeUtilities.cpp
@@ -50,6 +50,7 @@ wasm::ValType WebAssembly::toValType(MVT Type) {
   case MVT::v8i16:
   case MVT::v4i32:
   case MVT::v2i64:
+  case MVT::v8f16:
   case MVT::v4f32:
   case MVT::v2f64:
     return wasm::ValType::V128;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 527bb4c9fbea6..b0b2a9e55ae44 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -70,6 +70,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     addRegisterClass(MVT::v2i64, &WebAssembly::V128RegClass);
     addRegisterClass(MVT::v2f64, &WebAssembly::V128RegClass);
   }
+  if (Subtarget->hasHalfPrecision()) {
+    addRegisterClass(MVT::v8f16, &WebAssembly::V128RegClass);
+  }
   if (Subtarget->hasReferenceTypes()) {
     addRegisterClass(MVT::externref, &WebAssembly::EXTERNREFRegClass);
     addRegisterClass(MVT::funcref, &WebAssembly::FUNCREFRegClass);
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index af95dfa25a189..bb898e7bebd3a 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -38,6 +38,13 @@ multiclass RELAXED_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
                             asmstr_s, simdop, HasRelaxedSIMD>;
 }
 
+multiclass HALF_PRECISION_I<dag oops_r, dag iops_r, dag oops_s, dag iops_s,
+                            list<dag> pattern_r, string asmstr_r = "",
+                            string asmstr_s = "", bits<32> simdop = -1> {
+  defm "" : ABSTRACT_SIMD_I<oops_r, iops_r, oops_s, iops_s, pattern_r, asmstr_r,
+                            asmstr_s, simdop, HasHalfPrecision>;
+}
+
 
 defm "" : ARGUMENT<V128, v16i8>;
 defm "" : ARGUMENT<V128, v8i16>;
@@ -591,6 +598,14 @@ defm "" : Splat<I64x2, 18>;
 defm "" : Splat<F32x4, 19>;
 defm "" : Splat<F64x2, 20>;
 
+// Half values are not fully supported so an intrinsic is used instead of a
+// regular Splat pattern as above.
+defm SPLAT_F16x8 :
+  HALF_PRECISION_I<(outs V128:$dst), (ins F32:$x),
+                   (outs), (ins),
+                   [(set (v8f16 V128:$dst), (int_wasm_splat_f16x8 F32:$x))],
+                   "f16x8.splat\t$dst, $x", "f16x8.splat", 0x120>;
+
 // scalar_to_vector leaves high lanes undefined, so can be a splat
 foreach vec = AllVecs in
 def : Pat<(vec.vt (scalar_to_vector (vec.lane_vt vec.lane_rc:$x))),
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td
index ba2936b492a9a..4e2faa608be07 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterInfo.td
@@ -63,7 +63,8 @@ def I32 : WebAssemblyRegClass<[i32], 32, (add FP32, SP32, I32_0)>;
 def I64 : WebAssemblyRegClass<[i64], 64, (add FP64, SP64, I64_0)>;
 def F32 : WebAssemblyRegClass<[f32], 32, (add F32_0)>;
 def F64 : WebAssemblyRegClass<[f64], 64, (add F64_0)>;
-def V128 : WebAssemblyRegClass<[v4f32, v2f64, v2i64, v4i32, v16i8, v8i16], 128,
-                               (add V128_0)>;
+def V128 : WebAssemblyRegClass<[v8f16, v4f32, v2f64, v2i64, v4i32, v16i8,
+                                v8i16],
+                               128, (add V128_0)>;
 def FUNCREF : WebAssemblyRegClass<[funcref], 0, (add FUNCREF_0)>;
 def EXTERNREF : WebAssemblyRegClass<[externref], 0, (add EXTERNREF_0)>;
diff --git a/llvm/test/CodeGen/WebAssembly/half-precision.ll b/llvm/test/CodeGen/WebAssembly/half-precision.ll
index 89e9c42637c14..eee5bf8b8c48a 100644
--- a/llvm/test/CodeGen/WebAssembly/half-precision.ll
+++ b/llvm/test/CodeGen/WebAssembly/half-precision.ll
@@ -1,5 +1,5 @@
-; RUN: llc < %s --mtriple=wasm32-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision | FileCheck %s
-; RUN: llc < %s --mtriple=wasm64-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision | FileCheck %s
+; RUN: llc < %s --mtriple=wasm32-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision,+simd128 | FileCheck %s
+; RUN: llc < %s --mtriple=wasm64-unknown-unknown -asm-verbose=false -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+half-precision,+simd128 | FileCheck %s
 
 declare float @llvm.wasm.loadf32.f16(ptr)
 declare void @llvm.wasm.storef16.f32(float, ptr)
@@ -19,3 +19,11 @@ define void @stf16_32(float %v, ptr %p) {
   tail call void @llvm.wasm.storef16.f32(float %v, ptr %p)
   ret void
 }
+
+; CHECK-LABEL: splat_v8f16:
+; CHECK:       f16x8.splat $push0=, $0
+; CHECK-NEXT:  return $pop0
+define <8 x half> @splat_v8f16(float %x) {
+  %v = call <8 x half> @llvm.wasm.splat.f16x8(float %x)
+  ret <8 x half> %v
+}
diff --git a/llvm/test/MC/WebAssembly/simd-encodings.s b/llvm/test/MC/WebAssembly/simd-encodings.s
index 57fa71e74b8d7..c23a9d1958099 100644
--- a/llvm/test/MC/WebAssembly/simd-encodings.s
+++ b/llvm/test/MC/WebAssembly/simd-encodings.s
@@ -845,4 +845,7 @@ main:
     # CHECK: f32.store_f16 32 # encoding: [0xfc,0x31,0x01,0x20]
     f32.store_f16 32
 
+    # CHECK: f16x8.splat # encoding: [0xfd,0xa0,0x02]
+    f16x8.splat
+
     end_function



More information about the llvm-commits mailing list