[llvm] [SPIR-V] Fix return type when sampling an image with OpImageSampleExplicitLod (PR #89252)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 18 10:29:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR fixes the parsing of builtin return types in general, and in particular the return type used when sampling an image with OpImageSampleExplicitLod.

---
Full diff: https://github.com/llvm/llvm-project/pull/89252.diff


4 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+4-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+1-1) 
- (added) llvm/test/CodeGen/SPIRV/image/sampler.ll (+32) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 950f9df28dd397..4b07d7e61fa113 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1634,7 +1634,10 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
       ReturnType = ReturnType.substr(ReturnType.find("_R") + 2);
       ReturnType = ReturnType.substr(0, ReturnType.find('('));
     }
-    SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
+    SPIRVType *Type =
+        Call->ReturnType
+            ? Call->ReturnType
+            : GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
     if (!Type) {
       std::string DiagMsg =
           "Unable to recognize SPIRV type name: " + ReturnType;
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index 2e44c208ed8e04..871b95a28068e6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -369,7 +369,7 @@ bool isEntryPoint(const Function &F) {
   return false;
 }
 
-Type *parseBasicTypeName(StringRef TypeName, LLVMContext &Ctx) {
+Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
   TypeName.consume_front("atomic_");
   if (TypeName.consume_front("void"))
     return Type::getVoidTy(Ctx);
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index cd1a2af09147e3..6a91b6e576f9c7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -100,7 +100,7 @@ bool isSpecialOpaqueType(const Type *Ty);
 bool isEntryPoint(const Function &F);
 
 // Parse basic scalar type name, substring TypeName, and return LLVM type.
-Type *parseBasicTypeName(StringRef TypeName, LLVMContext &Ctx);
+Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx);
 
 // True if this is an instance of TypedPointerType.
 inline bool isTypedPointerTy(const Type *T) {
diff --git a/llvm/test/CodeGen/SPIRV/image/sampler.ll b/llvm/test/CodeGen/SPIRV/image/sampler.ll
new file mode 100644
index 00000000000000..7b45c95f5ed433
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/image/sampler.ll
@@ -0,0 +1,32 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#i32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#v4i32:]] = OpTypeVector %[[#i32]] 4
+; CHECK-DAG: %[[#ptrv4i32:]] = OpTypePointer CrossWorkgroup %[[#v4i32]]
+; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#typesampled:]] = OpTypeSampledImage
+; CHECK-DAG: %[[#const0:]] = OpConstant %[[#float]] 0
+; CHECK: OpFunction
+; CHECK: OpFunctionParameter
+; CHECK: %[[#arg1:]] = OpFunctionParameter
+; CHECK: %[[#arg2:]] = OpFunctionParameter
+; CHECK: %[[#addr:]] = OpInBoundsPtrAccessChain
+; CHECK: %[[#img:]] = OpSampledImage %[[#typesampled]] %[[#arg1]] %[[#arg2]]
+; CHECK: %[[#sample:]] = OpImageSampleExplicitLod %[[#v4i32]] %[[#img]] %[[#const0]] Lod %[[#const0]]
+; CHECK: %[[#casted:]] = OpBitcast %[[#ptrv4i32]] %[[#addr]]
+; CHECK: OpStore %[[#casted]] %[[#sample]] Aligned 16
+
+%"class.sycl::_V1::vec" = type { <4 x i32> }
+
+define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) align 16 %_arg_acc, target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 0) %_arg_img, target("spirv.Sampler") %_arg_sampler) {
+entry:
+  %data = getelementptr inbounds %"class.sycl::_V1::vec", ptr addrspace(1) %_arg_acc, i64 0
+  %img = tail call spir_func target("spirv.SampledImage", void, 0, 0, 0, 0, 0, 0, 0) @_Z20__spirv_SampledImage(target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 0) %_arg_img, target("spirv.Sampler") %_arg_sampler)
+  %sample = tail call spir_func <4 x i32> @_Z30__spirv_ImageSampleExplicitLod(target("spirv.SampledImage", void, 0, 0, 0, 0, 0, 0, 0) %img, float 0.000000e+00, i32 2, float 0.000000e+00)
+  store <4 x i32> %sample, ptr addrspace(1) %data, align 16
+  ret void
+}
+
+declare dso_local spir_func target("spirv.SampledImage", void, 0, 0, 0, 0, 0, 0, 0) @_Z20__spirv_SampledImage(target("spirv.Image", void, 0, 0, 0, 0, 0, 0, 0), target("spirv.Sampler"))
+declare dso_local spir_func <4 x i32> @_Z30__spirv_ImageSampleExplicitLod(target("spirv.SampledImage", void, 0, 0, 0, 0, 0, 0, 0), float, i32, float)

``````````

</details>


https://github.com/llvm/llvm-project/pull/89252


More information about the llvm-commits mailing list