[llvm] [SPIR-V] Explicitly emit vector element count for OpenCL vloadn calls (PR #81148)
Michal Paszkowski via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 8 11:28:39 PST 2024
https://github.com/michalpaszkowski updated https://github.com/llvm/llvm-project/pull/81148
>From 728895e1c7f5301327d43100b32b6739a8eadcfd Mon Sep 17 00:00:00 2001
From: Michal Paszkowski <michal at paszkowski.org>
Date: Thu, 8 Feb 2024 07:23:03 -0800
Subject: [PATCH 1/2] [SPIR-V] Explicitly emit vector element count for OpenCL
vloadn calls
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 2 +
llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 16 +++-
.../SPIRV/opencl/basic/vstore_private.ll | 95 -------------------
llvm/test/CodeGen/SPIRV/opencl/vload2.ll | 40 ++++++++
4 files changed, 53 insertions(+), 100 deletions(-)
delete mode 100644 llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
create mode 100644 llvm/test/CodeGen/SPIRV/opencl/vload2.ll
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index e4593e7db90e8..572a9afe14b26 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -114,6 +114,7 @@ struct VectorLoadStoreBuiltin {
StringRef Name;
InstructionSet::InstructionSet Set;
uint32_t Number;
+ uint32_t ElementCount; // Vector width n parsed from the builtin name (1 if none).
bool IsRounded;
FPRoundingMode::FPRoundingMode RoundingMode;
};
@@ -1851,6 +1852,10 @@ static bool generateVectorLoadStoreInst(const SPIRV::IncomingCall *Call,
.addImm(Builtin->Number);
for (auto Argument : Call->Arguments)
MIB.addUse(Argument);
+ // Only the vector load forms (vloadn, vload_halfn, vloada_halfn) take the
+ // element count n as a trailing literal; store and scalar forms must not.
+ if (Builtin->Name.contains("load") && Builtin->ElementCount > 1)
+ MIB.addImm(Builtin->ElementCount);
// Rounding mode should be passed as a last argument in the MI for builtins
// like "vstorea_halfn_r".
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 8acd4691787e4..63ca0a909b69c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -1046,18 +1046,24 @@ class VectorLoadStoreBuiltin<string name, InstructionSet set, int number> {
string Name = name;
InstructionSet Set = set;
bits<32> Number = number;
+ bits<32> ElementCount = !cond(!not(!eq(!find(name, "2"), -1)) : 2,
+ !not(!eq(!find(name, "3"), -1)) : 3,
+ !not(!eq(!find(name, "4"), -1)) : 4,
+ !not(!eq(!find(name, "8"), -1)) : 8,
+ !not(!eq(!find(name, "16"), -1)) : 16,
+ true : 1); // n = first of "2","3","4","8","16" (tested in this order) found in the name; 1 if none.
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
- !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
- !not(!eq(!find(name, "_rtp"), -1)) : RTP,
- !not(!eq(!find(name, "_rtn"), -1)) : RTN,
- true : RTE);
+ !not(!eq(!find(name, "_rtz"), -1)) : RTZ,
+ !not(!eq(!find(name, "_rtp"), -1)) : RTP,
+ !not(!eq(!find(name, "_rtn"), -1)) : RTN,
+ true : RTE);
}
// Table gathering all the vector data load/store builtins.
def VectorLoadStoreBuiltins : GenericTable {
let FilterClass = "VectorLoadStoreBuiltin";
- let Fields = ["Name", "Set", "Number", "IsRounded", "RoundingMode"];
+ let Fields = ["Name", "Set", "Number", "ElementCount", "IsRounded", "RoundingMode"];
string TypeOf_Set = "InstructionSet";
string TypeOf_RoundingMode = "FPRoundingMode";
}
diff --git a/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll b/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
deleted file mode 100644
index 40f1d59e4365e..0000000000000
--- a/llvm/test/CodeGen/SPIRV/opencl/basic/vstore_private.ll
+++ /dev/null
@@ -1,95 +0,0 @@
-; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
-
-; TODO(#60133): Requires updates following opaque pointer migration.
-; XFAIL: *
-
-; CHECK: %[[#i16_ty:]] = OpTypeInt 16 0
-; CHECK: %[[#v4xi16_ty:]] = OpTypeVector %[[#i16_ty]] 4
-; CHECK: %[[#pv4xi16_ty:]] = OpTypePointer Function %[[#v4xi16_ty]]
-; CHECK: %[[#i16_const0:]] = OpConstant %[[#i16_ty]] 0
-; CHECK: %[[#i16_undef:]] = OpUndef %[[#i16_ty]]
-; CHECK: %[[#comp_const:]] = OpConstantComposite %[[#v4xi16_ty]] %[[#i16_const0]] %[[#i16_const0]] %[[#i16_const0]] %[[#i16_undef]]
-
-; CHECK: %[[#r:]] = OpInBoundsPtrAccessChain
-; CHECK: %[[#r2:]] = OpBitcast %[[#pv4xi16_ty]] %[[#r]]
-; CHECK: OpStore %[[#r2]] %[[#comp_const]] Aligned 8
-
-define spir_kernel void @test_fn(i16 addrspace(1)* %srcValues, i32 addrspace(1)* %offsets, <3 x i16> addrspace(1)* %destBuffer, i32 %alignmentOffset) {
-entry:
- %sPrivateStorage = alloca [42 x <3 x i16>], align 8
- %0 = bitcast [42 x <3 x i16>]* %sPrivateStorage to i8*
- %1 = bitcast i8* %0 to i8*
- call void @llvm.lifetime.start.p0i8(i64 336, i8* %1)
- %2 = call spir_func <3 x i64> @BuiltInGlobalInvocationId()
- %call = extractelement <3 x i64> %2, i32 0
- %conv = trunc i64 %call to i32
- %idxprom = sext i32 %conv to i64
- %arrayidx = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 %idxprom
- %storetmp = bitcast <3 x i16>* %arrayidx to <4 x i16>*
- store <4 x i16> <i16 0, i16 0, i16 0, i16 undef>, <4 x i16>* %storetmp, align 8
- %conv1 = sext i32 %conv to i64
- %call2 = call spir_func <3 x i16> @OpenCL_vload3_i64_p1i16_i32(i64 %conv1, i16 addrspace(1)* %srcValues, i32 3)
- %idxprom3 = sext i32 %conv to i64
- %arrayidx4 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom3
- %3 = load i32, i32 addrspace(1)* %arrayidx4, align 4
- %conv5 = zext i32 %3 to i64
- %arraydecay = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 0
- %4 = bitcast <3 x i16>* %arraydecay to i16*
- %idx.ext = zext i32 %alignmentOffset to i64
- %add.ptr = getelementptr inbounds i16, i16* %4, i64 %idx.ext
- call spir_func void @OpenCL_vstore3_v3i16_i64_p0i16(<3 x i16> %call2, i64 %conv5, i16* %add.ptr)
- %arraydecay6 = getelementptr inbounds [42 x <3 x i16>], [42 x <3 x i16>]* %sPrivateStorage, i64 0, i64 0
- %5 = bitcast <3 x i16>* %arraydecay6 to i16*
- %idxprom7 = sext i32 %conv to i64
- %arrayidx8 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom7
- %6 = load i32, i32 addrspace(1)* %arrayidx8, align 4
- %mul = mul i32 3, %6
- %idx.ext9 = zext i32 %mul to i64
- %add.ptr10 = getelementptr inbounds i16, i16* %5, i64 %idx.ext9
- %idx.ext11 = zext i32 %alignmentOffset to i64
- %add.ptr12 = getelementptr inbounds i16, i16* %add.ptr10, i64 %idx.ext11
- %7 = bitcast <3 x i16> addrspace(1)* %destBuffer to i16 addrspace(1)*
- %idxprom13 = sext i32 %conv to i64
- %arrayidx14 = getelementptr inbounds i32, i32 addrspace(1)* %offsets, i64 %idxprom13
- %8 = load i32, i32 addrspace(1)* %arrayidx14, align 4
- %mul15 = mul i32 3, %8
- %idx.ext16 = zext i32 %mul15 to i64
- %add.ptr17 = getelementptr inbounds i16, i16 addrspace(1)* %7, i64 %idx.ext16
- %idx.ext18 = zext i32 %alignmentOffset to i64
- %add.ptr19 = getelementptr inbounds i16, i16 addrspace(1)* %add.ptr17, i64 %idx.ext18
- br label %for.cond
-
-for.cond: ; preds = %for.inc, %entry
- %i.0 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]
- %cmp = icmp ult i32 %i.0, 3
- br i1 %cmp, label %for.body, label %for.end
-
-for.body: ; preds = %for.cond
- %idxprom21 = zext i32 %i.0 to i64
- %arrayidx22 = getelementptr inbounds i16, i16* %add.ptr12, i64 %idxprom21
- %9 = load i16, i16* %arrayidx22, align 2
- %idxprom23 = zext i32 %i.0 to i64
- %arrayidx24 = getelementptr inbounds i16, i16 addrspace(1)* %add.ptr19, i64 %idxprom23
- store i16 %9, i16 addrspace(1)* %arrayidx24, align 2
- br label %for.inc
-
-for.inc: ; preds = %for.body
- %inc = add i32 %i.0, 1
- br label %for.cond
-
-for.end: ; preds = %for.cond
- %10 = bitcast [42 x <3 x i16>]* %sPrivateStorage to i8*
- %11 = bitcast i8* %10 to i8*
- call void @llvm.lifetime.end.p0i8(i64 336, i8* %11)
- ret void
-}
-
-declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture)
-
-declare spir_func <3 x i16> @OpenCL_vload3_i64_p1i16_i32(i64, i16 addrspace(1)*, i32)
-
-declare spir_func void @OpenCL_vstore3_v3i16_i64_p0i16(<3 x i16>, i64, i16*)
-
-declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture)
-
-declare spir_func <3 x i64> @BuiltInGlobalInvocationId()
diff --git a/llvm/test/CodeGen/SPIRV/opencl/vload2.ll b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
new file mode 100644
index 0000000000000..f7d380b96a3ef
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/opencl/vload2.ll
@@ -0,0 +1,40 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; This test only intends to check the vloadn builtin lowering (element count 2 emitted as the last operand).
+; The calls to the OpenCL builtins are not valid and will not pass SPIR-V validation.
+
+; CHECK-DAG: %[[#IMPORT:]] = OpExtInstImport "OpenCL.std"
+
+; CHECK-DAG: %[[#INT8:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#INT16:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#INT32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#INT64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#FLOAT:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#VINT8:]] = OpTypeVector %[[#INT8]] 2
+; CHECK-DAG: %[[#VINT16:]] = OpTypeVector %[[#INT16]] 2
+; CHECK-DAG: %[[#VINT32:]] = OpTypeVector %[[#INT32]] 2
+; CHECK-DAG: %[[#VINT64:]] = OpTypeVector %[[#INT64]] 2
+; CHECK-DAG: %[[#VFLOAT:]] = OpTypeVector %[[#FLOAT]] 2
+; CHECK-DAG: %[[#PTRINT8:]] = OpTypePointer CrossWorkgroup %[[#INT8]]
+
+; CHECK: %[[#OFFSET:]] = OpFunctionParameter %[[#INT64]]
+; CHECK: %[[#ADDRESS:]] = OpFunctionParameter %[[#PTRINT8]]
+
+define spir_kernel void @test_fn(i64 %offset, ptr addrspace(1) %address) {
+; CHECK: %[[#]] = OpExtInst %[[#VINT8]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+ %call1 = call spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT16]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+ %call2 = call spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT32]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+ %call3 = call spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VINT64]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+ %call4 = call spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64 %offset, ptr addrspace(1) %address)
+; CHECK: %[[#]] = OpExtInst %[[#VFLOAT]] %[[#IMPORT]] vloadn %[[#OFFSET]] %[[#ADDRESS]] 2
+ %call5 = call spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64 %offset, ptr addrspace(1) %address)
+ ret void
+}
+
+declare spir_func <2 x i8> @_Z6vload2mPU3AS1Kc(i64, ptr addrspace(1))
+declare spir_func <2 x i16> @_Z6vload2mPU3AS1Ks(i64, ptr addrspace(1))
+declare spir_func <2 x i32> @_Z6vload2mPU3AS1Ki(i64, ptr addrspace(1))
+declare spir_func <2 x i64> @_Z6vload2mPU3AS1Kl(i64, ptr addrspace(1))
+declare spir_func <2 x float> @_Z6vload2mPU3AS1Kf(i64, ptr addrspace(1))
>From 44bb5c0059773c5154d6844178dcc7b65f7b606f Mon Sep 17 00:00:00 2001
From: Michal Paszkowski <michal at paszkowski.org>
Date: Thu, 8 Feb 2024 11:28:14 -0800
Subject: [PATCH 2/2] [WIP]
---
llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 26a5d7a30f19d..ed70fd167343c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -290,7 +290,6 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
Value *Pointer;
Type *ExpectedElementType;
unsigned OperandToReplace;
- bool AllowCastingToChar = false;
StoreInst *SI = dyn_cast<StoreInst>(I);
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
@@ -308,7 +307,6 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
Pointer = Arg;
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
OperandToReplace = 0;
- AllowCastingToChar = true;
} else if (SI) {
Pointer = SI->getPointerOperand();
ExpectedElementType = SI->getValueOperand()->getType();
@@ -390,10 +388,20 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
}
// Do not emit spv_ptrcast if it would cast to the default pointer element
- // type (i8) of the same address space.
- if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
+ // type (i8) of the same address space. For OpenCL kernels, the cast is only
+ // skipped (below) for kernel arguments whose OCL arg type starts with "uchar*".
+ if (ExpectedElementType->isIntegerTy(8) &&
+ F->getCallingConv() != CallingConv::SPIR_KERNEL) // NOTE(review): in kernels, i8 casts on non-argument pointers are no longer skipped (previously skipped unless AllowCastingToChar) -- confirm intended.
return;
+ Argument *Arg = dyn_cast<Argument>(Pointer);
+ if (ExpectedElementType->isIntegerTy(8) &&
+ F->getCallingConv() == CallingConv::SPIR_KERNEL && Arg) { // NOTE(review): the kernel check is implied by the early return above.
+ MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
+ if (ArgType && ArgType->getString().starts_with("uchar*")) // NOTE(review): "char*" arguments are also i8; confirm only "uchar*" should skip.
+ return;
+ }
+
// If this would be the first spv_ptrcast, the pointer's defining instruction
// requires spv_assign_ptr_type and does not already have one, do not emit
// spv_ptrcast and emit spv_assign_ptr_type instead.
More information about the llvm-commits mailing list