[llvm] [SPIR-V] Explicitly emit vector element count for OpenCL vloadn calls (PR #81148)

Michal Paszkowski via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 8 07:36:02 PST 2024


https://github.com/michalpaszkowski created https://github.com/llvm/llvm-project/pull/81148

This pull request fixes a missing vector element count operand in OpExtInst vloadn calls. The element count must be emitted as the trailing literal operand, for example:
```
%call = OpExtInst %v2uchar %1 vloadn %conv1 %add_ptr 2
```


>From 70b94ac73f54ba9084d3594e72d4514f116bcc6a Mon Sep 17 00:00:00 2001
From: Michal Paszkowski <michal at paszkowski.org>
Date: Thu, 8 Feb 2024 07:23:03 -0800
Subject: [PATCH] [SPIR-V] Explicitly emit vector element count for OpenCL
 vloadn calls

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       |  4 +-
 llvm/lib/Target/SPIRV/SPIRVBuiltins.td        | 16 +++-
 .../SPIRV/opencl/basic/vstore_private.ll      | 95 -------------------
 llvm/test/CodeGen/SPIRV/opencl/vload2.ll      | 40 ++++++++
 4 files changed, 54 insertions(+), 101 deletions(-)
 delete mode 100644 llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/opencl/vload2.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index e4593e7db90e8b..7a83ea77f199f8 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -114,6 +114,7 @@ struct VectorLoadStoreBuiltin {
   StringRef Name;
   InstructionSet::InstructionSet Set;
   uint32_t Number;
+  uint32_t ElementCount;
   bool IsRounded;
   FPRoundingMode::FPRoundingMode RoundingMode;
 };
@@ -1851,7 +1852,12 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
           .addImm(Builtin->Number);
   for (auto Argument : Call->Arguments)
     MIB.addUse(Argument);
+  // The OpenCL.std vloadn, vload_halfn and vloada_halfn instructions take the
+  // vector size n as a trailing literal operand. Store builtins and scalar
+  // loads have no such operand, so only emit it for multi-element loads.
+  if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
+    MIB.addImm(Builtin->ElementCount);
 
   // Rounding mode should be passed as a last argument in the MI for builtins
   // like "vstorea_halfn_r".
   if (Builtin->IsRounded)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 8acd4691787e4c..63ca0a909b69c3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1046,18 +1046,28 @@ class VectorLoadStoreBuiltin<string name, InstructionSet set, int number> {
   string Name = name;
   InstructionSet Set = set;
   bits<32> Number = number;
+  // Vector size parsed from the builtin name (e.g. "vload2" -> 2); names
+  // that contain none of the sizes below yield 1. Testing "16" after the
+  // single-digit sizes is safe since "16" contains none of those digits.
+  bits<32> ElementCount = !cond(!not(!eq(!find(name, "2"), -1)) : 2,
+                                !not(!eq(!find(name, "3"), -1)) : 3,
+                                !not(!eq(!find(name, "4"), -1)) : 4,
+                                !not(!eq(!find(name, "8"), -1)) : 8,
+                                !not(!eq(!find(name, "16"), -1)) : 16,
+                                true : 1);
   bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
   FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
-                                  !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
-                                  !not(!eq(!find(name, "_rtp"), -1)) : RTP,
-                                  !not(!eq(!find(name, "_rtn"), -1)) : RTN,
-                                  true : RTE);
+                                      !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
+                                      !not(!eq(!find(name, "_rtp"), -1)) : RTP,
+                                      !not(!eq(!find(name, "_rtn"), -1)) : RTN,
+                                      true : RTE);
 }
 
 // Table gathering all the vector data load/store builtins.
 def VectorLoadStoreBuiltins : GenericTable {
   let FilterClass = "VectorLoadStoreBuiltin";
-  let Fields = ["Name", "Set", "Number", "IsRounded", "RoundingMode"];
+  let Fields = ["Name", "Set", "Number", "ElementCount", "IsRounded",
+                "RoundingMode"];
   string TypeOf_Set = "InstructionSet";
   string TypeOf_RoundingMode = "FPRoundingMode";
 }
diff --git a/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll b/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
deleted file mode 100644
index 40f1d59e4365e1..00000000000000
--- a/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
+++ /dev/null
@@ -1,95 +0,0 @@
-; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
-
-; TODO(#60133): Requires updates following opaque pointer migration.
-; XFAIL: *
-
-; CHECK: %[[#i16_ty:]] = OpTypeInt 16 0
-; CHECK: %[[#v4xi16_ty:]] = OpTypeVector %[[#i16_ty]] 4
-; CHECK: %[[#pv4xi16_ty:]] = OpTypePointer Function %[[#v4xi16_ty]]
-; CHECK: %[[#i16_const0:]] = OpConstant %[[#i16_ty]] 0
-; CHECK: %[[#i16_undef:]] = OpUndef %[[#i16_ty]]
-; CHECK: %[[#comp_const:]] = OpConstantComposite %[[#v4xi16_ty]] %[[#i16_const0]] %[[#i16_const0]] %[[#i16_const0]] %[[#i16_undef]]
-
-; CHECK: %[[#r:]] = OpInBoundsPtrAccessChain
-; CHECK: %[[#r2:]] = OpBitcast %[[#pv4xi16_ty]] %[[#r]]
-; CHECK: OpStore %[[#r2]] %[[#comp_const]] Aligned 8
-
-define spir_kernel void @test_fn(i16 addrspace(1)* %srcValues, i32 addrspace(1)* %offsets, <3 x i16> addrspace(1)* %destBuffer, i32 %alignmentOffset) {
-entry:
-  %sPrivateStorage = alloca [42 x <3 x i16>], align 8
-  %0 = bitcast [42 x <3 x i16>]* %sPrivateStorage to i8*
-  %1 = bitcast i8* %0 to i8*
-  call void @llvm.lifetime.start.p0i8(i64 336, i8* %1)
-  %2 = call spir_func <3 x i64> @BuiltInGlobalInvocationId()
-  %call = extractelement <3 x i64> %2, i32 0
-  %conv = trunc i64 %call to i32
-  %idxprom = sext i32 %conv to i64
-  %arrayidx = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 %idxprom
-  %storetmp = bitcast <3 x i16>* %arrayidx to <4 x i16>*
-  store <4 x i16> <i16 0, i16 0, i16 0, i16 undef>, <4 x i16>* %storetmp, align 8
-  %conv1 = sext i32 %conv to i64
-  %call2 = call spir_func <3 x i16> @OpenCL_vload3_i64_p1i16_i32(i64 %conv1, i16 addrspace(1)* %srcValues, i32 3)
-  %idxprom3 = sext i32 %conv to i64
-  %arrayidx4 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom3
-  %3 = load i32, i32 addrspace(1)* %arrayidx4, align 4
-  %conv5 = zext i32 %3 to i64
-  %arraydecay = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 0
-  %4 = bitcast <3 x i16>* %arraydecay to i16*
-  %idx.ext = zext i32 %alignmentOffset to i64
-  %add.ptr = getelementptr inbounds i16, i16* %4, i64 %idx.ext
-  call spir_func void @OpenCL_vstore3_v3i16_i64_p0i16(<3 x i16> %call2, i64 %conv5, i16* %add.ptr)
-  %arraydecay6 = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 0
-  %5 = bitcast <3 x i16>* %arraydecay6 to i16*
-  %idxprom7 = sext i32 %conv to i64
-  %arrayidx8 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom7
-  %6 = load i32, i32 addrspace(1)* %arrayidx8, align 4
-  %mul = mul i32 3, %6
-  %idx.ext9 = zext i32 %mul to i64
-  %add.ptr10 = getelementptr inbounds i16, i16* %5, i64 %idx.ext9
-  %idx.ext11 = zext i32 %alignmentOffset to i64
-  %add.ptr12 = getelementptr inbounds i16, i16* %add.ptr10, i64 %idx.ext11
-  %7 = bitcast <3 x i16> addrspace(1)* %destBuffer to i16 addrspace(1)*
-  %idxprom13 = sext i32 %conv to i64
-  %arrayidx14 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom13
-  %8 = load i32, i32 addrspace(1)* %arrayidx14, align 4
-  %mul15 = mul i32 3, %8
-  %idx.ext16 = zext i32 %mul15 to i64
-  %add.ptr17 = getelementptr inbounds i16, i16 addrspace(1)* %7, i64 %idx.ext16
-  %idx.ext18 = zext i32 %alignmentOffset to i64
-  %add.ptr19 = getelementptr inbounds i16, i16 addrspace(1)* %add.ptr17, i64 %idx.ext18
-  br label %for.cond
-
-for.cond:                                         ; preds = %for.inc, %entry
-  %i.0 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]
-  %cmp = icmp ult i32 %i.0, 3
-  br i1 %cmp, label %for.body, label %for.end
-
-for.body:                                         ; preds = %for.cond
-  %idxprom21 = zext i32 %i.0 to i64
-  %arrayidx22 = getelementptr inbounds i16, i16* %add.ptr12, i64 %idxprom21
-  %9 = load i16, i16* %arrayidx22, align 2
-  %idxprom23 = zext i32 %i.0 to i64
-  %arrayidx24 = getelementptr inbounds i16, i16 addrspace(1)* %add.ptr19, i64 %idxprom23
-  store i16 %9, i16 addrspace(1)* %arrayidx24, align 2
-  br label %for.inc
-
-for.inc:                                          ; preds = %for.body
-  %inc = add i32 %i.0, 1
-  br label %for.cond
-
-for.end:                                          ; preds = %for.cond
-  %10 = bitcast [42 x <3 x i16>]* %sPrivateStorage to i8*
-  %11 = bitcast i8* %10 to i8*
-  call void @llvm.lifetime.end.p0i8(i64 336, i8* %11)
-  ret void
-}
-
-declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture)
-
-declare spir_func <3 x i16> @OpenCL_vload3_i64_p1i16_i32(i64, i16 addrspace(1)*, i32)
-
-declare spir_func void @OpenCL_vstore3_v3i16_i64_p0i16(<3 x i16>, i64, i16*)
-
-declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture)
-
-declare spir_func <3 x i64> @BuiltInGlobalInvocationId()
diff --git a/llvm/test/CodeGen/SPIRV/opencl/vload2.ll b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
new file mode 100644
index 00000000000000..f7d380b96a3ef0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
@@ -0,0 +1,40 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; This test only intends to check the vloadn builtin lowering.
+; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.
+
+; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"
+
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#INT16:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
+; CHECK-DAG: %[[#VINT16:]] = OpTypeVector %[[#INT16]] 2
+; CHECK-DAG: %[[#VINT32:]] = OpTypeVector %[[#INT32]] 2
+; CHECK-DAG: %[[#VINT64:]] = OpTypeVector %[[#INT64]] 2
+; CHECK-DAG: %[[#VFLOAT:]] = OpTypeVector %[[#FLOAT]] 2
+; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
+
+; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
+; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]
+
+define spir_kernel void @test_fn(i64 %offset, ptr addrspace(1) %address) {
+; CHECK: %[[#]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call1 = call spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call2 = call spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call3 = call spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call4 = call spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+  %call5 = call spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64 %offset, ptr addrspace(1) %address)
+  ret void
+}
+
+declare spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64, ptr addrspace(1))
+declare spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64, ptr addrspace(1))
+declare spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64, ptr addrspace(1))
+declare spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64, ptr addrspace(1))
+declare spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64, ptr addrspace(1))



More information about the llvm-commits mailing list