[llvm] 47e996d - [SPIR-V] Fix OpVariable instructions place in a function (#87554)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 4 01:50:39 PDT 2024


Author: Vyacheslav Levytskyy
Date: 2024-04-04T10:50:35+02:00
New Revision: 47e996d89d4d1e229451594d4b0752b71e8e231c

URL: https://github.com/llvm/llvm-project/commit/47e996d89d4d1e229451594d4b0752b71e8e231c
DIFF: https://github.com/llvm/llvm-project/commit/47e996d89d4d1e229451594d4b0752b71e8e231c.diff

LOG: [SPIR-V] Fix OpVariable instructions place in a function (#87554)

This PR:
* fixes the placement of OpVariable instructions in a function (see
https://github.com/llvm/llvm-project/issues/66261),
* improves type inference,
* helps avoid unneeded bitcasts when validating function calls.

This allows improving existing test cases and adding new ones with
stricter checks. The OpVariable fix addresses the SPIR-V spec requirement
that "All OpVariable instructions in a function must be the first
instructions in the first block".

Added: 
    llvm/test/CodeGen/SPIRV/pointers/type-deduce-call-no-bitcast.ll

Modified: 
    llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
    llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
    llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
    llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
    llvm/test/CodeGen/SPIRV/OpVariable_order.ll
    llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 1674cef7cb8270..9e4ba2191366b3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -243,8 +243,12 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
       continue;
 
     MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
-    SPIRVType *ElementType = GR->getOrCreateSPIRVType(
-        cast<ConstantAsMetadata>(VMD->getMetadata())->getType(), MIRBuilder);
+    Type *ElementTy = cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
+    if (isUntypedPointerTy(ElementTy))
+      ElementTy =
+          TypedPointerType::get(IntegerType::getInt8Ty(II->getContext()),
+                                getPointerAddressSpace(ElementTy));
+    SPIRVType *ElementType = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder);
     return GR->getOrCreateSPIRVPointerType(
         ElementType, MIRBuilder,
         addressSpaceToStorageClass(

diff  --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index e0099e52944725..ac799374adce8c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -47,7 +47,7 @@ class SPIRVGlobalRegistry {
   DenseMap<const MachineOperand *, const Function *> InstrToFunction;
   // Maps Functions to their calls (in a form of the machine instruction,
   // OpFunctionCall) that happened before the definition is available
-  DenseMap<const Function *, SmallVector<MachineInstr *>> ForwardCalls;
+  DenseMap<const Function *, SmallPtrSet<MachineInstr *, 8>> ForwardCalls;
 
   // Look for an equivalent of the newType in the map. Return the equivalent
   // if it's found, otherwise insert newType to the map and return the type.
@@ -215,12 +215,12 @@ class SPIRVGlobalRegistry {
     if (It == ForwardCalls.end())
       ForwardCalls[F] = {MI};
     else
-      It->second.push_back(MI);
+      It->second.insert(MI);
   }
 
   // Map a Function to the vector of machine instructions that represents
   // forward function calls or to nullptr if not found.
-  SmallVector<MachineInstr *> *getForwardCalls(const Function *F) {
+  SmallPtrSet<MachineInstr *, 8> *getForwardCalls(const Function *F) {
     auto It = ForwardCalls.find(F);
     return It == ForwardCalls.end() ? nullptr : &It->second;
   }

diff  --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 90a31551f45a23..d450078d793fb7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -193,7 +193,7 @@ void validateForwardCalls(const SPIRVSubtarget &STI,
                           MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR,
                           MachineInstr &FunDef) {
   const Function *F = GR.getFunctionByDefinition(&FunDef);
-  if (SmallVector<MachineInstr *> *FwdCalls = GR.getForwardCalls(F))
+  if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F))
     for (MachineInstr *FunCall : *FwdCalls) {
       MachineRegisterInfo *CallMRI =
           &FunCall->getParent()->getParent()->getRegInfo();

diff  --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f4525e713c987f..49749b56345306 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1825,7 +1825,24 @@ bool SPIRVInstructionSelector::selectAllocaArray(Register ResVReg,
 bool SPIRVInstructionSelector::selectFrameIndex(Register ResVReg,
                                                 const SPIRVType *ResType,
                                                 MachineInstr &I) const {
-  return BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpVariable))
+  // Change order of instructions if needed: all OpVariable instructions in a
+  // function must be the first instructions in the first block
+  MachineFunction *MF = I.getParent()->getParent();
+  MachineBasicBlock *MBB = &MF->front();
+  auto It = MBB->SkipPHIsAndLabels(MBB->begin()), E = MBB->end();
+  bool IsHeader = false;
+  unsigned Opcode;
+  for (; It != E && It != I; ++It) {
+    Opcode = It->getOpcode();
+    if (Opcode == SPIRV::OpFunction || Opcode == SPIRV::OpFunctionParameter) {
+      IsHeader = true;
+    } else if (IsHeader &&
+               !(Opcode == SPIRV::ASSIGN_TYPE || Opcode == SPIRV::OpLabel)) {
+      ++It;
+      break;
+    }
+  }
+  return BuildMI(*MBB, It, It->getDebugLoc(), TII.get(SPIRV::OpVariable))
       .addDef(ResVReg)
       .addUse(GR.getSPIRVTypeID(ResType))
       .addImm(static_cast<uint32_t>(SPIRV::StorageClass::Function))

diff  --git a/llvm/test/CodeGen/SPIRV/OpVariable_order.ll b/llvm/test/CodeGen/SPIRV/OpVariable_order.ll
index a4ca3aa709f0fa..6057bf38d4c4c4 100644
--- a/llvm/test/CodeGen/SPIRV/OpVariable_order.ll
+++ b/llvm/test/CodeGen/SPIRV/OpVariable_order.ll
@@ -1,10 +1,14 @@
-; REQUIRES: spirv-tools
-; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - -filetype=obj | not spirv-val 2>&1 | FileCheck %s
+; All OpVariable instructions in a function must be the first instructions in the first block
 
-; TODO(#66261): The SPIR-V backend should reorder OpVariable instructions so this doesn't fail,
-;     but in the meantime it's a good example of the spirv-val tool working as intended.
+; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-linux %s -o - -filetype=obj | spirv-val %}
 
-; CHECK: All OpVariable instructions in a function must be the first instructions in the first block.
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV-NEXT: OpLabel
+; CHECK-SPIRV-NEXT: OpVariable
+; CHECK-SPIRV-NEXT: OpVariable
+; CHECK-SPIRV: OpReturn
+; CHECK-SPIRV: OpFunctionEnd
 
 define void @main() #1 {
 entry:

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
index 1071d3443056cb..b039f80860daf6 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
@@ -10,22 +10,46 @@
 ; CHECK-SPIRV-DAG: OpName %[[FooObj:.*]] "foo_object"
 ; CHECK-SPIRV-DAG: OpName %[[FooMemOrder:.*]] "mem_order"
 ; CHECK-SPIRV-DAG: OpName %[[FooFunc:.*]] "foo"
+
 ; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 32 0
 ; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
 ; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer CrossWorkgroup %[[TyLong]]
 ; CHECK-SPIRV-DAG: %[[TyFunPtrLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrLong]]
-; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyGenPtrPtrLong:.*]] = OpTypePointer Generic %[[TyGenPtrLong]]
 ; CHECK-SPIRV-DAG: %[[TyFunGenPtrLongLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrLong]] %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyChar:.*]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[TyGenPtrChar:.*]] = OpTypePointer Generic %[[TyChar]]
+; CHECK-SPIRV-DAG: %[[TyGenPtrPtrChar:.*]] = OpTypePointer Generic %[[TyGenPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyFunPtrGenPtrChar:.*]] = OpTypePointer Function %[[TyGenPtrChar]]
 ; CHECK-SPIRV-DAG: %[[Const3:.*]] = OpConstant %[[TyLong]] 3
+
 ; CHECK-SPIRV: %[[FunTest]] = OpFunction %[[TyVoid]] None %[[TyFunPtrLong]]
 ; CHECK-SPIRV: %[[ArgCum]] = OpFunctionParameter %[[TyPtrLong]]
+
 ; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[Addr]] %[[Const3]]
+
+; CHECK-SPIRV: %[[HalfAddr:.*]] = OpPtrCastToGeneric
+; CHECK-SPIRV-NEXT: %[[HalfAddrCasted:.*]] = OpBitcast %[[TyGenPtrLong]] %[[HalfAddr]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[HalfAddrCasted]] %[[Const3]]
+
+; CHECK-SPIRV: %[[DblAddr:.*]] = OpPtrCastToGeneric
+; CHECK-SPIRV-NEXT: %[[DblAddrCasted:.*]] = OpBitcast %[[TyGenPtrLong]] %[[DblAddr]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[DblAddrCasted]] %[[Const3]]
+
 ; CHECK-SPIRV: %[[FooStub]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
 ; CHECK-SPIRV: %[[StubObj]] = OpFunctionParameter %[[TyGenPtrLong]]
 ; CHECK-SPIRV: %[[MemOrder]] = OpFunctionParameter %[[TyLong]]
+
+; CHECK-SPIRV: %[[ObjectAddr:.*]] = OpVariable %[[TyFunPtrGenPtrChar]] Function
+; CHECK-SPIRV-NEXT: %[[ToGeneric:.*]] = OpPtrCastToGeneric %[[TyGenPtrPtrChar]] %[[ObjectAddr]]
+; CHECK-SPIRV-NEXT: %[[Casted:.*]] = OpBitcast %[[TyGenPtrPtrLong]] %[[ToGeneric]]
+; CHECK-SPIRV-NEXT: OpStore %[[Casted]] %[[StubObj]]
+
 ; CHECK-SPIRV: %[[FooFunc]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
 ; CHECK-SPIRV: %[[FooObj]] = OpFunctionParameter %[[TyGenPtrLong]]
 ; CHECK-SPIRV: %[[FooMemOrder]] = OpFunctionParameter %[[TyLong]]
+
 ; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooStub]] %[[FooObj]] %[[FooMemOrder]]
 
 define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_cum) {

diff  --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-call-no-bitcast.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-call-no-bitcast.ll
new file mode 100644
index 00000000000000..edb31ffeee8e86
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-call-no-bitcast.ll
@@ -0,0 +1,60 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
+; CHECK-SPIRV-DAG: %[[TyChar:.*]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[TyGenPtrChar:.*]] = OpTypePointer Generic %[[TyChar]]
+; CHECK-SPIRV-DAG: %[[TyFunBar:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 64 0
+; CHECK-SPIRV-DAG: %[[TyGenPtrPtrChar:.*]] = OpTypePointer Generic %[[TyGenPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyFunFoo:.*]] = OpTypeFunction %[[TyVoid]] %[[TyLong]] %[[TyGenPtrPtrChar]] %[[TyGenPtrPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyStruct:.*]] = OpTypeStruct %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[Const100:.*]] = OpConstant %[[TyLong]] 100
+; CHECK-SPIRV-DAG: %[[TyFunPtrGenPtrChar:.*]] = OpTypePointer Function %[[TyGenPtrChar]]
+; CHECK-SPIRV-DAG: %[[TyPtrStruct:.*]] = OpTypePointer Generic %[[TyStruct]]
+; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
+
+; CHECK-SPIRV: %[[Bar:.*]] = OpFunction %[[TyVoid]] None %[[TyFunBar]]
+; CHECK-SPIRV: %[[BarArg:.*]] = OpFunctionParameter %[[TyGenPtrChar]]
+; CHECK-SPIRV-NEXT: OpLabel
+; CHECK-SPIRV-NEXT: OpVariable %[[TyFunPtrGenPtrChar]] Function
+; CHECK-SPIRV-NEXT: OpVariable %[[TyFunPtrGenPtrChar]] Function
+; CHECK-SPIRV-NEXT: OpVariable %[[TyFunPtrGenPtrChar]] Function
+; CHECK-SPIRV: %[[Var1:.*]] = OpPtrCastToGeneric %[[TyGenPtrPtrChar]] %[[#]]
+; CHECK-SPIRV: %[[Var2:.*]] = OpPtrCastToGeneric %[[TyGenPtrPtrChar]] %[[#]]
+; CHECK-SPIRV: OpStore %[[#]] %[[BarArg]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[TyVoid]] %[[Foo]] %[[Const100]] %[[Var1]] %[[Var2]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[TyVoid]] %[[Foo]] %[[Const100]] %[[Var2]] %[[Var1]]
+
+; CHECK-SPIRV: %[[Foo]] = OpFunction %[[TyVoid]] None %[[TyFunFoo]]
+; CHECK-SPIRV-NEXT: OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV-NEXT: OpFunctionParameter %[[TyGenPtrPtrChar]]
+; CHECK-SPIRV-NEXT: OpFunctionParameter %[[TyGenPtrPtrChar]]
+
+%class.CustomType = type { i64 }
+
+define linkonce_odr dso_local spir_func void @bar(ptr addrspace(4) noundef %first) {
+entry:
+  %first.addr = alloca ptr addrspace(4)
+  %first.addr.ascast = addrspacecast ptr %first.addr to ptr addrspace(4)
+  %temp = alloca ptr addrspace(4), align 8
+  %temp.ascast = addrspacecast ptr %temp to ptr addrspace(4)
+  store ptr addrspace(4) %first, ptr %first.addr
+  call spir_func void @foo(i64 noundef 100, ptr addrspace(4) noundef dereferenceable(8) %first.addr.ascast, ptr addrspace(4) noundef dereferenceable(8) %temp.ascast)
+  call spir_func void @foo(i64 noundef 100, ptr addrspace(4) noundef dereferenceable(8) %temp.ascast, ptr addrspace(4) noundef dereferenceable(8) %first.addr.ascast)
+  %var = alloca ptr addrspace(4), align 8
+  ret void
+}
+
+define linkonce_odr dso_local spir_func void @foo(i64 noundef %offset, ptr addrspace(4) noundef dereferenceable(8) %in_acc1, ptr addrspace(4) noundef dereferenceable(8) %out_acc1) {
+entry:
+  %r0 = load ptr addrspace(4), ptr addrspace(4) %in_acc1
+  %arrayidx = getelementptr inbounds %class.CustomType, ptr addrspace(4) %r0, i64 42
+  %r1 = load i64, ptr addrspace(4) %arrayidx
+  %r3 = load ptr addrspace(4), ptr addrspace(4) %out_acc1
+  %r4 = getelementptr %class.CustomType, ptr addrspace(4) %r3, i64 43
+  store i64 %r1, ptr addrspace(4) %r4
+  ret void
+}
+


        


More information about the llvm-commits mailing list