[llvm] [SPIR-V] Allow non-const arguments in a Group builtin that requires a boolean argument (PR #102902)
Vyacheslav Levytskyy via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 12 06:48:52 PDT 2024
https://github.com/VyacheslavLevytskyy created https://github.com/llvm/llvm-project/pull/102902
This PR resolves a TODO in `generateGroupInst()` (`lib/Target/SPIRV/SPIRVBuiltins.cpp`) and issues https://github.com/llvm/llvm-project/issues/97311 and https://github.com/llvm/llvm-project/issues/97312 by implementing support for non-constant arguments to Group builtins that require a boolean argument.
>From 5018620b74e98321eb9a77e64763e89884676401 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Mon, 12 Aug 2024 06:46:21 -0700
Subject: [PATCH] Allow non-const arguments in a Group builtin that requires a
boolean argument
---
llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 34 ++++++---
.../SPIRV/transcoding/OpGroupAllAny.ll | 72 ++++++++++++++++++-
2 files changed, 93 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 1609576c038d01..e3fe37e71975e9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1091,16 +1091,30 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
Register Arg0;
if (GroupBuiltin->HasBoolArg) {
- Register ConstRegister = Call->Arguments[0];
- auto ArgInstruction = getDefInstrMaybeConstant(ConstRegister, MRI);
- (void)ArgInstruction;
- // TODO: support non-constant bool values.
- assert(ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT &&
- "Only constant bool value args are supported");
- if (GR->getSPIRVTypeForVReg(Call->Arguments[0])->getOpcode() !=
- SPIRV::OpTypeBool)
- Arg0 = GR->buildConstantInt(getIConstVal(ConstRegister, MRI), MIRBuilder,
- GR->getOrCreateSPIRVBoolType(MIRBuilder));
+ SPIRVType *BoolType = GR->getOrCreateSPIRVBoolType(MIRBuilder);
+ Register BoolReg = Call->Arguments[0];
+ SPIRVType *BoolRegType = GR->getSPIRVTypeForVReg(BoolReg);
+ if (!BoolRegType)
+ report_fatal_error("Can't find a register's type definition");
+ MachineInstr *ArgInstruction = getDefInstrMaybeConstant(BoolReg, MRI);
+ if (ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT) {
+ if (BoolRegType->getOpcode() != SPIRV::OpTypeBool)
+ Arg0 = GR->buildConstantInt(getIConstVal(BoolReg, MRI), MIRBuilder,
+ BoolType);
+ } else {
+ if (BoolRegType->getOpcode() == SPIRV::OpTypeInt) {
+ Arg0 = MRI->createGenericVirtualRegister(LLT::scalar(1));
+ MRI->setRegClass(Arg0, &SPIRV::IDRegClass);
+ GR->assignSPIRVTypeToVReg(BoolType, Arg0, MIRBuilder.getMF());
+ MIRBuilder.buildICmp(CmpInst::ICMP_NE, Arg0, BoolReg,
+ GR->buildConstantInt(0, MIRBuilder, BoolRegType));
+ insertAssignInstr(Arg0, nullptr, BoolType, GR, MIRBuilder,
+ MIRBuilder.getMF().getRegInfo());
+ } else if (BoolRegType->getOpcode() != SPIRV::OpTypeBool) {
+ report_fatal_error("Expect a boolean argument");
+ }
+ // if BoolReg is a boolean register, we don't need to do anything
+ }
}
Register GroupResultRegister = Call->ReturnRegister;
diff --git a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
index 39f9e0581d2f8e..0b420697763ba5 100644
--- a/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
+++ b/llvm/test/CodeGen/SPIRV/transcoding/OpGroupAllAny.ll
@@ -8,11 +8,12 @@
; CHECK-SPIRV-DAG: %[[#BoolTypeID:]] = OpTypeBool
; CHECK-SPIRV-DAG: %[[#True:]] = OpConstantTrue %[[#BoolTypeID]]
; CHECK-SPIRV-DAG: %[[#False:]] = OpConstantFalse %[[#BoolTypeID]]
+
+; CHECK-SPIRV: OpFunction
; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#True]]
; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#True]]
; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#False]]
-
define spir_kernel void @test(i32 addrspace(1)* nocapture readnone %i) {
entry:
%call = tail call spir_func i32 @_Z14work_group_alli(i32 5)
@@ -22,8 +23,73 @@ entry:
ret void
}
-declare spir_func i32 @_Z14work_group_alli(i32)
-declare spir_func i32 @_Z14work_group_anyi(i32)
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: %[[#]] = OpGroupAny %[[#BoolTypeID]] %[[#]] %[[#]]
+; CHECK-SPIRV: %[[#]] = OpGroupAll %[[#BoolTypeID]] %[[#]] %[[#]]
+define spir_kernel void @test_nonconst_any(ptr addrspace(1) %input, ptr addrspace(1) %output) #0 !kernel_arg_addr_space !7 !kernel_arg_access_qual !8 !kernel_arg_type !9 !kernel_arg_type_qual !10 !kernel_arg_base_type !9 !spirv.ParameterDecorations !11 {
+entry:
+ %r0 = call spir_func i64 @_Z13get_global_idj(i32 0)
+ %r1 = insertelement <3 x i64> undef, i64 %r0, i32 0
+ %r2 = call spir_func i64 @_Z13get_global_idj(i32 1)
+ %r3 = insertelement <3 x i64> %r1, i64 %r2, i32 1
+ %r4 = call spir_func i64 @_Z13get_global_idj(i32 2)
+ %r5 = insertelement <3 x i64> %r3, i64 %r4, i32 2
+ %call = extractelement <3 x i64> %r5, i32 0
+ %conv = trunc i64 %call to i32
+ %idxprom = sext i32 %conv to i64
+ %arrayidx = getelementptr inbounds float, ptr addrspace(1) %input, i64 %idxprom
+ %r6 = load float, ptr addrspace(1) %arrayidx, align 4
+ %add = add nsw i32 %conv, 1
+ %idxprom1 = sext i32 %add to i64
+ %arrayidx2 = getelementptr inbounds float, ptr addrspace(1) %input, i64 %idxprom1
+ %r7 = load float, ptr addrspace(1) %arrayidx2, align 4
+ %cmp = fcmp ogt float %r6, %r7
+ %conv3 = select i1 %cmp, i32 1, i32 0
+ %r8 = icmp ne i32 %conv3, 0
+ %r9 = zext i1 %r8 to i32
+ %r10 = call spir_func i32 @_Z14work_group_anyi(i32 %r9)
+ %call41 = icmp ne i32 %r10, 0
+ %call4 = select i1 %call41, i32 1, i32 0
+ %idxprom5 = sext i32 %conv to i64
+ %arrayidx6 = getelementptr inbounds i32, ptr addrspace(1) %output, i64 %idxprom5
+ store i32 %call4, ptr addrspace(1) %arrayidx6, align 4
+ %r11 = call spir_func i32 @_Z14work_group_alli(i32 %r9)
+ %call42 = icmp ne i32 %r11, 0
+ %call5 = select i1 %call42, i32 1, i32 0
+ store i32 %call5, ptr addrspace(1) %arrayidx6, align 4
+ ret void
+}
+
+declare spir_func i64 @_Z13get_global_idj(i32) #1
+
+declare spir_func i32 @_Z14work_group_alli(i32) #2
+declare spir_func i32 @_Z14work_group_anyi(i32) #2
declare spir_func i1 @__spirv_GroupAll(i32, i1)
declare spir_func i1 @__spirv_GroupAny(i32, i1)
+
+attributes #0 = { nounwind }
+attributes #1 = { nounwind willreturn memory(none) }
+attributes #2 = { convergent nounwind }
+
+!spirv.MemoryModel = !{!0}
+!opencl.enable.FP_CONTRACT = !{}
+!spirv.Source = !{!1}
+!opencl.spir.version = !{!2}
+!opencl.ocl.version = !{!3}
+!opencl.used.extensions = !{!4}
+!opencl.used.optional.core.features = !{!5}
+!spirv.Generator = !{!6}
+
+!0 = !{i32 2, i32 2}
+!1 = !{i32 3, i32 300000}
+!2 = !{i32 2, i32 0}
+!3 = !{i32 3, i32 0}
+!4 = !{!"cl_khr_subgroups"}
+!5 = !{}
+!6 = !{i16 6, i16 14}
+!7 = !{i32 1, i32 1}
+!8 = !{!"none", !"none"}
+!9 = !{!"float*", !"int*"}
+!10 = !{!"", !""}
+!11 = !{!5, !5}
More information about the llvm-commits
mailing list