[llvm] 03203b7 - [SPIR-V] Fix vloadn OpenCL builtin lowering (#81148)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 20 20:04:07 PST 2024


Author: Michal Paszkowski
Date: 2024-02-20T20:04:04-08:00
New Revision: 03203b79c6247465850ee6e9f3e2399afc35720b

URL: https://github.com/llvm/llvm-project/commit/03203b79c6247465850ee6e9f3e2399afc35720b
DIFF: https://github.com/llvm/llvm-project/commit/03203b79c6247465850ee6e9f3e2399afc35720b.diff

LOG: [SPIR-V] Fix vloadn OpenCL builtin lowering (#81148)

This pull request fixes an issue with a missing vector element count
immediate operand in OpExtInst instructions and adds a case for generating
bitcasts before GEPs for kernel arguments of a non-matching pointer type.
The new LIT tests are based on the basic/vload_local and basic/vload_global
OpenCL CTS tests. With this change, those CTS tests pass SPIR-V validation.

Added: 
    llvm/test/CodeGen/SPIRV/opencl/vload2.ll
    llvm/test/CodeGen/SPIRV/pointers/getelementptr-kernel-arg-char.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
    llvm/lib/Target/SPIRV/SPIRVBuiltins.td
    llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Removed: 
    llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 8a354dd04640b0..c1bb27322443ff 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -141,6 +141,7 @@ struct VectorLoadStoreBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
   uint32_t Number;
+  uint32_t ElementCount;
   bool IsRounded;
   FPRoundingMode::FPRoundingMode RoundingMode;
 };
@@ -2042,6 +2043,7 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
           .addImm(Builtin->Number);
   for (auto Argument : Call->Arguments)
     MIB.addUse(Argument);
+  MIB.addImm(Builtin->ElementCount);
 
   // Rounding mode should be passed as a last argument in the MI for builtins
   // like "vstorea_halfn_r".

diff  --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 571cfcfd6e7e5c..e6e3560d02f58b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1236,18 +1236,24 @@ class VectorLoadStoreBuiltin<string name, InstructionSet set, int number> {
   string Name = name;
   InstructionSet Set = set;
   bits<32> Number = number;
+  bits<32> ElementCount = !cond(!not(!eq(!find(name, "2"), -1)) : 2,
+                                !not(!eq(!find(name, "3"), -1)) : 3,
+                                !not(!eq(!find(name, "4"), -1)) : 4,
+                                !not(!eq(!find(name, "8"), -1)) : 8,
+                                !not(!eq(!find(name, "16"), -1)) : 16,
+                                true : 1);
   bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
   FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
-                                  !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
-                                  !not(!eq(!find(name, "_rtp"), -1)) : RTP,
-                                  !not(!eq(!find(name, "_rtn"), -1)) : RTN,
-                                  true : RTE);
+                                      !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
+                                      !not(!eq(!find(name, "_rtp"), -1)) : RTP,
+                                      !not(!eq(!find(name, "_rtn"), -1)) : RTN,
+                                      true : RTE);
 }
 
 // Table gathering all the vector data load/store builtins.
 def VectorLoadStoreBuiltins : GenericTable {
   let FilterClass = "VectorLoadStoreBuiltin";
-  let Fields = ["Name", "Set", "Number", "IsRounded", "RoundingMode"];
+  let Fields = ["Name", "Set", "Number", "ElementCount", "IsRounded", "RoundingMode"];
   string TypeOf_Set = "InstructionSet";
   string TypeOf_RoundingMode = "FPRoundingMode";
 }

diff  --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 26a5d7a30f19dd..e32cd50be56e38 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -290,25 +290,14 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
   Value *Pointer;
   Type *ExpectedElementType;
   unsigned OperandToReplace;
-  bool AllowCastingToChar = false;
 
   StoreInst *SI = dyn_cast<StoreInst>(I);
   if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
       SI->getValueOperand()->getType()->isPointerTy() &&
       isa<Argument>(SI->getValueOperand())) {
-    Argument *Arg = cast<Argument>(SI->getValueOperand());
-    MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
-    if (!ArgType || ArgType->getString().starts_with("uchar*"))
-      return;
-
-    // Handle special case when StoreInst's value operand is a kernel argument
-    // of a pointer type. Since these arguments could have either a basic
-    // element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
-    // the StoreInst's value operand to default pointer element type (i8).
-    Pointer = Arg;
+    Pointer = SI->getValueOperand();
     ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
     OperandToReplace = 0;
-    AllowCastingToChar = true;
   } else if (SI) {
     Pointer = SI->getPointerOperand();
     ExpectedElementType = SI->getValueOperand()->getType();
@@ -390,10 +379,20 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
   }
 
   // Do not emit spv_ptrcast if it would cast to the default pointer element
-  // type (i8) of the same address space.
-  if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
+  // type (i8) of the same address space. In case of OpenCL kernels, make sure
+  // i8 is the pointer element type defined for the given kernel argument.
+  if (ExpectedElementType->isIntegerTy(8) &&
+      F->getCallingConv() != CallingConv::SPIR_KERNEL)
     return;
 
+  Argument *Arg = dyn_cast<Argument>(Pointer);
+  if (ExpectedElementType->isIntegerTy(8) &&
+      F->getCallingConv() == CallingConv::SPIR_KERNEL && Arg) {
+    MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
+    if (ArgType && ArgType->getString().starts_with("uchar*"))
+      return;
+  }
+
   // If this would be the first spv_ptrcast, the pointer's defining instruction
   // requires spv_assign_ptr_type and does not already have one, do not emit
   // spv_ptrcast and emit spv_assign_ptr_type instead.

diff  --git a/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll b/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
deleted file mode 100644
index 40f1d59e4365e1..00000000000000
--- a/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
+++ /dev/null
@@ -1,95 +0,0 @@
-; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
-
-; TODO(#60133): Requires updates following opaque pointer migration.
-; XFAIL: *
-
-; CHECK: %[[#i16_ty:]] = OpTypeInt 16 0
-; CHECK: %[[#v4xi16_ty:]] = OpTypeVector %[[#i16_ty]] 4
-; CHECK: %[[#pv4xi16_ty:]] = OpTypePointer Function %[[#v4xi16_ty]]
-; CHECK: %[[#i16_const0:]] = OpConstant %[[#i16_ty]] 0
-; CHECK: %[[#i16_undef:]] = OpUndef %[[#i16_ty]]
-; CHECK: %[[#comp_const:]] = OpConstantComposite %[[#v4xi16_ty]] %[[#i16_const0]] %[[#i16_const0]] %[[#i16_const0]] %[[#i16_undef]]
-
-; CHECK: %[[#r:]] = OpInBoundsPtrAccessChain
-; CHECK: %[[#r2:]] = OpBitcast %[[#pv4xi16_ty]] %[[#r]]
-; CHECK: OpStore %[[#r2]] %[[#comp_const]] Aligned 8
-
-define spir_kernel void @test_fn(i16 addrspace(1)* %srcValues, i32 addrspace(1)* %offsets, <3 x i16> addrspace(1)* %destBuffer, i32 %alignmentOffset) {
-entry:
-  %sPrivateStorage = alloca [42 x <3 x i16>], align 8
-  %0 = bitcast [42 x <3 x i16>]* %sPrivateStorage to i8*
-  %1 = bitcast i8* %0 to i8*
-  call void @llvm.lifetime.start.p0i8(i64 336, i8* %1)
-  %2 = call spir_func <3 x i64> @BuiltInGlobalInvocationId()
-  %call = extractelement <3 x i64> %2, i32 0
-  %conv = trunc i64 %call to i32
-  %idxprom = sext i32 %conv to i64
-  %arrayidx = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 %idxprom
-  %storetmp = bitcast <3 x i16>* %arrayidx to <4 x i16>*
-  store <4 x i16> <i16 0, i16 0, i16 0, i16 undef>, <4 x i16>* %storetmp, align 8
-  %conv1 = sext i32 %conv to i64
-  %call2 = call spir_func <3 x i16> @OpenCL_vload3_i64_p1i16_i32(i64 %conv1, i16 addrspace(1)* %srcValues, i32 3)
-  %idxprom3 = sext i32 %conv to i64
-  %arrayidx4 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom3
-  %3 = load i32, i32 addrspace(1)* %arrayidx4, align 4
-  %conv5 = zext i32 %3 to i64
-  %arraydecay = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 0
-  %4 = bitcast <3 x i16>* %arraydecay to i16*
-  %idx.ext = zext i32 %alignmentOffset to i64
-  %add.ptr = getelementptr inbounds i16, i16* %4, i64 %idx.ext
-  call spir_func void @OpenCL_vstore3_v3i16_i64_p0i16(<3 x i16> %call2, i64 %conv5, i16* %add.ptr)
-  %arraydecay6 = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 0
-  %5 = bitcast <3 x i16>* %arraydecay6 to i16*
-  %idxprom7 = sext i32 %conv to i64
-  %arrayidx8 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom7
-  %6 = load i32, i32 addrspace(1)* %arrayidx8, align 4
-  %mul = mul i32 3, %6
-  %idx.ext9 = zext i32 %mul to i64
-  %add.ptr10 = getelementptr inbounds i16, i16* %5, i64 %idx.ext9
-  %idx.ext11 = zext i32 %alignmentOffset to i64
-  %add.ptr12 = getelementptr inbounds i16, i16* %add.ptr10, i64 %idx.ext11
-  %7 = bitcast <3 x i16> addrspace(1)* %destBuffer to i16 addrspace(1)*
-  %idxprom13 = sext i32 %conv to i64
-  %arrayidx14 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom13
-  %8 = load i32, i32 addrspace(1)* %arrayidx14, align 4
-  %mul15 = mul i32 3, %8
-  %idx.ext16 = zext i32 %mul15 to i64
-  %add.ptr17 = getelementptr inbounds i16, i16 addrspace(1)* %7, i64 %idx.ext16
-  %idx.ext18 = zext i32 %alignmentOffset to i64
-  %add.ptr19 = getelementptr inbounds i16, i16 addrspace(1)* %add.ptr17, i64 %idx.ext18
-  br label %for.cond
-
-for.cond:                                         ; preds = %for.inc, %entry
-  %i.0 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]
-  %cmp = icmp ult i32 %i.0, 3
-  br i1 %cmp, label %for.body, label %for.end
-
-for.body:                                         ; preds = %for.cond
-  %idxprom21 = zext i32 %i.0 to i64
-  %arrayidx22 = getelementptr inbounds i16, i16* %add.ptr12, i64 %idxprom21
-  %9 = load i16, i16* %arrayidx22, align 2
-  %idxprom23 = zext i32 %i.0 to i64
-  %arrayidx24 = getelementptr inbounds i16, i16 addrspace(1)* %add.ptr19, i64 %idxprom23
-  store i16 %9, i16 addrspace(1)* %arrayidx24, align 2
-  br label %for.inc
-
-for.inc:                                          ; preds = %for.body
-  %inc = add i32 %i.0, 1
-  br label %for.cond
-
-for.end:                                          ; preds = %for.cond
-  %10 = bitcast [42 x <3 x i16>]* %sPrivateStorage to i8*
-  %11 = bitcast i8* %10 to i8*
-  call void @llvm.lifetime.end.p0i8(i64 336, i8* %11)
-  ret void
-}
-
-declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture)
-
-declare spir_func <3 x i16> @OpenCL_vload3_i64_p1i16_i32(i64, i16 addrspace(1)*, i32)
-
-declare spir_func void @OpenCL_vstore3_v3i16_i64_p0i16(<3 x i16>, i64, i16*)
-
-declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture)
-
-declare spir_func <3 x i64> @BuiltInGlobalInvocationId()

diff  --git a/llvm/test/CodeGen/SPIRV/opencl/vload2.ll b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
new file mode 100644
index 00000000000000..b219aebc29befe
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
@@ -0,0 +1,40 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; This test only intends to check the vloadn builtin name resolution.
+; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.
+
+; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"
+
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#INT16:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
+; CHECK-DAG: %[[#VINT16:]] = OpTypeVector %[[#INT16]] 2
+; CHECK-DAG: %[[#VINT32:]] = OpTypeVector %[[#INT32]] 2
+; CHECK-DAG: %[[#VINT64:]] = OpTypeVector %[[#INT64]] 2
+; CHECK-DAG: %[[#VFLOAT:]] = OpTypeVector %[[#FLOAT]] 2
+; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
+
+; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
+; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]
+
+define spir_kernel void @test_fn(i64 %offset, ptr addrspace(1) %address) {
+; CHECK: %[[#]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call1 = call spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call2 = call spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call3 = call spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call4 = call spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call5 = call spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64 %offset, ptr addrspace(1) %address)
+  ret void
+}
+
+declare spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64, ptr addrspace(1))
+declare spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64, ptr addrspace(1))
+declare spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64, ptr addrspace(1))
+declare spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64, ptr addrspace(1))
+declare spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64, ptr addrspace(1))

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/getelementptr-kernel-arg-char.ll b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-kernel-arg-char.ll
new file mode 100644
index 00000000000000..cca71d409d258d
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/getelementptr-kernel-arg-char.ll
@@ -0,0 +1,31 @@
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
+; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer Workgroup %[[#INT8]]
+; CHECK-DAG: %[[#PTRVINT8:]] = OpTypePointer Workgroup %[[#VINT8]]
+; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#INT64]] 1
+
+; CHECK: %[[#PARAM1:]] = OpFunctionParameter %[[#PTRVINT8]]
+define spir_kernel void @test1(ptr addrspace(3) %address) !kernel_arg_type !1 {
+; CHECK: %[[#BITCAST1:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM1]]
+; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST1]] %[[#CONST]]
+  %cast = bitcast ptr addrspace(3) %address to ptr addrspace(3)
+  %gep = getelementptr inbounds i8, ptr addrspace(3) %cast, i64 1
+  ret void
+}
+
+; CHECK: %[[#PARAM2:]] = OpFunctionParameter %[[#PTRVINT8]]
+define spir_kernel void @test2(ptr addrspace(3) %address) !kernel_arg_type !1 {
+; CHECK: %[[#BITCAST2:]] = OpBitcast %[[#PTRINT8]] %[[#PARAM2]]
+; CHECK: %[[#]] = OpInBoundsPtrAccessChain %[[#PTRINT8]] %[[#BITCAST2]] %[[#CONST]]
+  %gep = getelementptr inbounds i8, ptr addrspace(3) %address, i64 1
+  ret void
+}
+
+declare spir_func <2 x i8> @_Z6vload2mPU3AS3Kc(i64, ptr addrspace(3))
+
+!1 = !{!"char2*"}


        


More information about the llvm-commits mailing list