[llvm] [SPIR-V] Add Float16 support when targeting Vulkan (PR #77115)

Natalie Chouinard via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 10 12:45:19 PST 2024


https://github.com/sudonatalie updated https://github.com/llvm/llvm-project/pull/77115

From d43986ca21e7c8fdb9a75c9b02a8472769d49258 Mon Sep 17 00:00:00 2001
From: Natalie Chouinard <chouinard at google.com>
Date: Fri, 5 Jan 2024 16:25:57 +0000
Subject: [PATCH 1/4] [SPIR-V] Add Float16 support when targeting Vulkan

Add Float16 to the capabilities available when targeting Vulkan, and guard
Float16Buffer (a Kernel-only capability) so that it is only added in OpenCL
environments.

Add tests covering the half type and half vectors, and validate the output
with spirv-val.
---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 10 +++++----
 llvm/test/CodeGen/SPIRV/basic_float_types.ll  | 22 +++++++++++++++++++
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 2a830535a2aa13..2222241d51b2e6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -597,8 +597,9 @@ void RequirementHandler::initAvailableCapabilitiesForVulkan(
     const SPIRVSubtarget &ST) {
   addAvailableCaps({Capability::Shader, Capability::Linkage});
 
-  // Provided by Vulkan version 1.0.
-  addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float64});
+  // Provided by all supported Vulkan versions.
+  addAvailableCaps({Capability::Int16, Capability::Int64, Capability::Float16,
+                    Capability::Float64});
 }
 
 } // namespace SPIRV
@@ -733,11 +734,12 @@ void addInstrRequirements(const MachineInstr &MI,
     auto SC = MI.getOperand(1).getImm();
     Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                                ST);
-    // If it's a type of pointer to float16, add Float16Buffer capability.
+    // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
+    // capability.
     assert(MI.getOperand(2).isReg());
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
-    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+    if (ST.isOpenCLEnv() && TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
         TypeDef->getOperand(1).getImm() == 16)
       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
     break;
diff --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
index 4287adc85cfd84..8d2641b6a34f0e 100644
--- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll
+++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
@@ -1,12 +1,18 @@
 ; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 define void @main() {
 entry:
+; CHECK-DAG:     %[[#half:]] = OpTypeFloat 16
 ; CHECK-DAG:    %[[#float:]] = OpTypeFloat 32
 ; CHECK-DAG:   %[[#double:]] = OpTypeFloat 64
 
+; CHECK-DAG:   %[[#v2half:]] = OpTypeVector %[[#half]] 2
+; CHECK-DAG:   %[[#v3half:]] = OpTypeVector %[[#half]] 3
+; CHECK-DAG:   %[[#v4half:]] = OpTypeVector %[[#half]] 4
+
 ; CHECK-DAG:  %[[#v2float:]] = OpTypeVector %[[#float]] 2
 ; CHECK-DAG:  %[[#v3float:]] = OpTypeVector %[[#float]] 3
 ; CHECK-DAG:  %[[#v4float:]] = OpTypeVector %[[#float]] 4
@@ -15,8 +21,12 @@ entry:
 ; CHECK-DAG: %[[#v3double:]] = OpTypeVector %[[#double]] 3
 ; CHECK-DAG: %[[#v4double:]] = OpTypeVector %[[#double]] 4
 
+; CHECK-DAG:     %[[#ptr_Function_half:]] = OpTypePointer Function %[[#half]]
 ; CHECK-DAG:    %[[#ptr_Function_float:]] = OpTypePointer Function %[[#float]]
 ; CHECK-DAG:   %[[#ptr_Function_double:]] = OpTypePointer Function %[[#double]]
+; CHECK-DAG:   %[[#ptr_Function_v2half:]] = OpTypePointer Function %[[#v2half]]
+; CHECK-DAG:   %[[#ptr_Function_v3half:]] = OpTypePointer Function %[[#v3half]]
+; CHECK-DAG:   %[[#ptr_Function_v4half:]] = OpTypePointer Function %[[#v4half]]
 ; CHECK-DAG:  %[[#ptr_Function_v2float:]] = OpTypePointer Function %[[#v2float]]
 ; CHECK-DAG:  %[[#ptr_Function_v3float:]] = OpTypePointer Function %[[#v3float]]
 ; CHECK-DAG:  %[[#ptr_Function_v4float:]] = OpTypePointer Function %[[#v4float]]
@@ -24,12 +34,24 @@ entry:
 ; CHECK-DAG: %[[#ptr_Function_v3double:]] = OpTypePointer Function %[[#v3double]]
 ; CHECK-DAG: %[[#ptr_Function_v4double:]] = OpTypePointer Function %[[#v4double]]
 
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_half]] Function
+  %half_Val = alloca half, align 2
+
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_float]] Function
   %float_Val = alloca float, align 4
 
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_double]] Function
   %double_Val = alloca double, align 8
 
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2half]] Function
+  %half2_Val = alloca <2 x half>, align 4
+
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v3half]] Function
+  %half3_Val = alloca <3 x half>, align 8
+
+; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v4half]] Function
+  %half4_Val = alloca <4 x half>, align 8
+
 ; CHECK: %[[#]] = OpVariable %[[#ptr_Function_v2float]] Function
   %float2_Val = alloca <2 x float>, align 8
 

From abcf7d6bda31ff4b34e3d3a1421b91beba5b66ea Mon Sep 17 00:00:00 2001
From: Natalie Chouinard <chouinard at google.com>
Date: Tue, 9 Jan 2024 15:02:37 +0000
Subject: [PATCH 2/4] Add OpCapability CHECKs

---
 llvm/test/CodeGen/SPIRV/basic_float_types.ll | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/llvm/test/CodeGen/SPIRV/basic_float_types.ll b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
index 8d2641b6a34f0e..1c7a8a851f59c6 100644
--- a/llvm/test/CodeGen/SPIRV/basic_float_types.ll
+++ b/llvm/test/CodeGen/SPIRV/basic_float_types.ll
@@ -5,6 +5,10 @@
 
 define void @main() {
 entry:
+
+; CHECK-DAG: OpCapability Float16
+; CHECK-DAG: OpCapability Float64
+
 ; CHECK-DAG:     %[[#half:]] = OpTypeFloat 16
 ; CHECK-DAG:    %[[#float:]] = OpTypeFloat 32
 ; CHECK-DAG:   %[[#double:]] = OpTypeFloat 64

From 48f08c08dca9eced35a055d3152bbaeea9e4b6fd Mon Sep 17 00:00:00 2001
From: Natalie Chouinard <chouinard at google.com>
Date: Wed, 10 Jan 2024 17:02:56 +0000
Subject: [PATCH 3/4] Break out of OpTypePointer handling earlier for non-OpenCL environments

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 2222241d51b2e6..eb60b677acaa11 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -731,6 +731,8 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpTypePointer: {
+    if (!ST.isOpenCLEnv())
+      break;
     auto SC = MI.getOperand(1).getImm();
     Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                                ST);
@@ -739,7 +741,7 @@ void addInstrRequirements(const MachineInstr &MI,
     assert(MI.getOperand(2).isReg());
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
-    if (ST.isOpenCLEnv() && TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
+    if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
         TypeDef->getOperand(1).getImm() == 16)
       Reqs.addCapability(SPIRV::Capability::Float16Buffer);
     break;

From 1f0cb859a1d4a1a5b935abe2ddc4d70e683baacc Mon Sep 17 00:00:00 2001
From: Natalie Chouinard <chouinard at google.com>
Date: Wed, 10 Jan 2024 20:44:20 +0000
Subject: [PATCH 4/4] Move the non-OpenCL early break after adding storage class requirements

---
 llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index eb60b677acaa11..027e4c2b46d873 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -731,13 +731,13 @@ void addInstrRequirements(const MachineInstr &MI,
     break;
   }
   case SPIRV::OpTypePointer: {
-    if (!ST.isOpenCLEnv())
-      break;
     auto SC = MI.getOperand(1).getImm();
     Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
                                ST);
     // If it's a type of pointer to float16 targeting OpenCL, add Float16Buffer
     // capability.
+    if (!ST.isOpenCLEnv())
+      break;
     assert(MI.getOperand(2).isReg());
     const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());



More information about the llvm-commits mailing list