[llvm] Improve how lowering of formal arguments in the SPIR-V backend interprets the value of 'kernel_arg_type' (PR #78730)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 22 03:08:41 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/78730

>From 6f3fb5155c0b86a1386e6cbdc7e54ce95bfc6d5c Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 19 Jan 2024 06:58:49 -0800
Subject: [PATCH 1/4] improve how lowering of formal arguments interprets a
 value of 'kernel_arg_type'

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   | 35 ++++++++++---------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 12 ++++---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  1 +
 3 files changed, 27 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 0a8b5499a1fc2a..b6c151b1e73c85 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -209,23 +209,24 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
       isSpecialOpaqueType(OriginalArgType))
     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 
-  MDString *MDKernelArgType =
-      getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
-  if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
-                           !MDKernelArgType->getString().ends_with("_t")))
-    return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
-
-  if (MDKernelArgType->getString().ends_with("*"))
-    return GR->getOrCreateSPIRVTypeByName(
-        MDKernelArgType->getString(), MIRBuilder,
-        addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));
-
-  if (MDKernelArgType->getString().ends_with("_t"))
-    return GR->getOrCreateSPIRVTypeByName(
-        "opencl." + MDKernelArgType->getString().str(), MIRBuilder,
-        SPIRV::StorageClass::Function, ArgAccessQual);
-
-  llvm_unreachable("Unable to recognize argument type name.");
+  SPIRVType *ResArgType = nullptr;
+  if (MDString *MDKernelArgType =
+          getKernelArgAttribute(F, ArgIdx, "kernel_arg_type")) {
+    StringRef MDTypeStr = MDKernelArgType->getString();
+    if (MDTypeStr.ends_with("*")) {
+      ResArgType = GR->getOrCreateSPIRVTypeByName(
+          MDTypeStr, MIRBuilder,
+          addressSpaceToStorageClass(
+              OriginalArgType->getPointerAddressSpace()));
+    } else if (MDTypeStr.ends_with("_t")) {
+      ResArgType = GR->getOrCreateSPIRVTypeByName(
+          "opencl." + MDTypeStr.str(), MIRBuilder,
+          SPIRV::StorageClass::Function, ArgAccessQual);
+    }
+  }
+  return ResArgType ? ResArgType
+                    : GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,
+                                               ArgAccessQual);
 }
 
 static bool isEntryPoint(const Function &F) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 6c009b9e8ddefa..f2c27467c34b49 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -443,8 +443,9 @@ Register SPIRVGlobalRegistry::buildConstantSampler(
   SPIRVType *SampTy;
   if (SpvType)
     SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
-  else
-    SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
+  else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",
+                                                MIRBuilder)) == nullptr)
+    report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");
 
   auto Sampler =
       ResReg.isValid()
@@ -941,6 +942,7 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
   return nullptr;
 }
 
+// Returns nullptr if unable to recognize SPIRV type name
 // TODO: maybe use tablegen to implement this.
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
     StringRef TypeStr, MachineIRBuilder &MIRBuilder,
@@ -992,8 +994,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
   } else if (TypeStr.starts_with("double")) {
     Ty = Type::getDoubleTy(Ctx);
     TypeStr = TypeStr.substr(strlen("double"));
-  } else
-    llvm_unreachable("Unable to recognize SPIRV type name.");
+  } else {
+    // Unable to recognize SPIRV type name
+    return nullptr;
+  }
 
   auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 60967bfb68a87e..f3280928c25dfa 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -138,6 +138,7 @@ class SPIRVGlobalRegistry {
 
   // Either generate a new OpTypeXXX instruction or return an existing one
   // corresponding to the given string containing the name of the builtin type.
+  // Return nullptr if unable to recognize SPIRV type name from `TypeStr`.
   SPIRVType *getOrCreateSPIRVTypeByName(
       StringRef TypeStr, MachineIRBuilder &MIRBuilder,
       SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,

>From ae1389ef50c8925996610eacd182d116070bcf9a Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 19 Jan 2024 08:33:04 -0800
Subject: [PATCH 2/4] add a test case

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  4 +++
 .../CodeGen/SPIRV/custom_kernel_arg_type.ll   | 33 +++++++++++++++++++
 2 files changed, 37 insertions(+)
 create mode 100644 llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c85bd27d256b2a..fd82212ae1b4de 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1386,6 +1386,10 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
       ReturnType = ReturnType.substr(0, ReturnType.find('('));
     }
     SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
+    if (!Type) {
+      std::string DiagMsg = "Unable to recognize SPIRV type name: " + ReturnType;
+      report_fatal_error(DiagMsg.c_str());
+    }
     MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
     MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
     MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
diff --git a/llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll b/llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll
new file mode 100644
index 00000000000000..a4971d064e7c76
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll
@@ -0,0 +1,33 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+
+; CHECK: %[[TyInt:.*]] = OpTypeInt 8 0
+; CHECK: %[[TyPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt]]
+; CHECK: OpFunctionParameter %[[TyPtr]]
+; CHECK: OpFunctionParameter %[[TyPtr]]
+
+%struct.my_kernel_data = type { i32, i32, i32, i32, i32 }
+%struct.my_struct = type { i32, i32 }
+
+define spir_kernel void @test(ptr addrspace(1) %in, ptr addrspace(1) %outData) !kernel_arg_type !5 {
+entry:
+  ret void
+}
+
+!llvm.module.flags = !{!0}
+!opencl.enable.FP_CONTRACT = !{}
+!opencl.ocl.version = !{!1}
+!opencl.spir.version = !{!2}
+!opencl.used.extensions = !{!3}
+!opencl.used.optional.core.features = !{!3}
+!opencl.compiler.options = !{!3}
+!llvm.ident = !{!4}
+!opencl.kernels = !{!6}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 1, i32 0}
+!2 = !{i32 1, i32 2}
+!3 = !{}
+!4 = !{!"clang version 6.0.0"}
+!5 = !{!"my_kernel_data*", !"struct my_struct*"}
+!6 = !{ptr @test}
+

>From d0e04ba36d2a6fd9fa5812abb29278b61fa4abe2 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 19 Jan 2024 09:04:30 -0800
Subject: [PATCH 3/4] apply clang-format

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp     | 3 ++-
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp | 5 ++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index fd82212ae1b4de..e4593e7db90e8b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1387,7 +1387,8 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
     }
     SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
     if (!Type) {
-      std::string DiagMsg = "Unable to recognize SPIRV type name: " + ReturnType;
+      std::string DiagMsg =
+          "Unable to recognize SPIRV type name: " + ReturnType;
       report_fatal_error(DiagMsg.c_str());
     }
     MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index b6c151b1e73c85..66a6763b0f039a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -213,16 +213,15 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
   if (MDString *MDKernelArgType =
           getKernelArgAttribute(F, ArgIdx, "kernel_arg_type")) {
     StringRef MDTypeStr = MDKernelArgType->getString();
-    if (MDTypeStr.ends_with("*")) {
+    if (MDTypeStr.ends_with("*"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           MDTypeStr, MIRBuilder,
           addressSpaceToStorageClass(
               OriginalArgType->getPointerAddressSpace()));
-    } else if (MDTypeStr.ends_with("_t")) {
+    else if (MDTypeStr.ends_with("_t"))
       ResArgType = GR->getOrCreateSPIRVTypeByName(
           "opencl." + MDTypeStr.str(), MIRBuilder,
           SPIRV::StorageClass::Function, ArgAccessQual);
-    }
   }
   return ResArgType ? ResArgType
                     : GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,

>From dcca4e78db360b2d95a1eb7f3151cd9e8ddbb3df Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 22 Jan 2024 03:08:29 -0800
Subject: [PATCH 4/4] add a run line to the test with the SPIR-V validator

---
 llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll b/llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll
index a4971d064e7c76..4593fad783c60e 100644
--- a/llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll
+++ b/llvm/test/CodeGen/SPIRV/custom_kernel_arg_type.ll
@@ -1,4 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK: %[[TyInt:.*]] = OpTypeInt 8 0
 ; CHECK: %[[TyPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt]]



More information about the llvm-commits mailing list