[llvm] [SPIR-V] Fix validity of atomic instructions (PR #87051)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 29 08:27:08 PDT 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/87051

>From 22bfb0345d5f5275891ef4b3dd95635604d9abc6 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 29 Mar 2024 04:24:31 -0700
Subject: [PATCH 1/2] fix validity of atomic instructions; improve type
 inference

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp | 85 ++++++++++++++++---
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   | 42 +++++++++
 .../test/CodeGen/SPIRV/instructions/atomic.ll | 28 ++++--
 .../SPIRV/instructions/atomic_acqrel.ll       | 28 ++++--
 .../CodeGen/SPIRV/instructions/atomic_seq.ll  | 28 ++++--
 .../SPIRV/pointers/bitcast-fix-accesschain.ll | 37 ++++++++
 6 files changed, 211 insertions(+), 37 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-accesschain.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 7c5a38fa48d009..3679e060cbe6f2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -65,6 +65,10 @@ class SPIRVEmitIntrinsics
   Type *deduceElementType(Value *I);
   Type *deduceElementTypeHelper(Value *I);
   Type *deduceElementTypeHelper(Value *I, std::unordered_set<Value *> &Visited);
+  Type *deduceElementTypeByValueDeep(Type *ValueTy, Value *Operand,
+                                     std::unordered_set<Value *> &Visited);
+  Type *deduceElementTypeByUsersDeep(Value *Op,
+                                     std::unordered_set<Value *> &Visited);
 
   // deduce nested types of composites
   Type *deduceNestedTypeHelper(User *U);
@@ -176,6 +180,44 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
                        false);
 }
 
+// Starts from ValueTy as the element type and, when Operand is non-null,
+// tries to refine it (recursively) using Operand; returns the resulting type.
+Type *SPIRVEmitIntrinsics::deduceElementTypeByValueDeep(
+    Type *ValueTy, Value *Operand, std::unordered_set<Value *> &Visited) {
+  Type *Ty = ValueTy; // result if no refinement below applies
+  if (Operand) {
+    if (auto *PtrTy = dyn_cast<PointerType>(Ty)) { // pointer: deduce pointee
+      if (Type *NestedTy = deduceElementTypeHelper(Operand, Visited))
+        Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+    } else { // non-pointer: let the nested-type helper refine Ty
+      Ty = deduceNestedTypeHelper(dyn_cast<User>(Operand), Ty, Visited);
+    }
+  }
+  return Ty;
+}
+
+// Deduce the element type of pointer Op via its user instructions, or nullptr.
+Type *SPIRVEmitIntrinsics::deduceElementTypeByUsersDeep(
+    Value *Op, std::unordered_set<Value *> &Visited) {
+  if (!Op || !isPointerTy(Op->getType()))
+    return nullptr; // nothing to deduce for a null or non-pointer operand
+
+  if (auto PType = dyn_cast<TypedPointerType>(Op->getType()))
+    return PType->getElementType(); // typed pointer carries its element type
+
+  // reuse an element type already recorded for the operand, if there is one
+  if (Type *KnownTy = GR->findDeducedElementType(Op))
+    return KnownTy;
+
+  for (User *OpU : Op->users()) { // the first user that yields a type wins
+    if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
+      if (Type *Ty = deduceElementTypeHelper(Inst, Visited))
+        return Ty;
+    }
+  }
+  return nullptr; // no user instruction produced a type
+}
+
 // Deduce and return a successfully deduced Type of the Instruction,
 // or nullptr otherwise.
 Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(Value *I) {
@@ -206,21 +248,27 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
   } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
     Ty = Ref->getResultElementType();
   } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
-    Ty = Ref->getValueType();
-    if (Value *Op = Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr) {
-      if (auto *PtrTy = dyn_cast<PointerType>(Ty)) {
-        if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
-          Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
-      } else {
-        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), Ty, Visited);
-      }
-    }
+    Ty = deduceElementTypeByValueDeep(
+        Ref->getValueType(),
+        Ref->getNumOperands() > 0 ? Ref->getOperand(0) : nullptr, Visited);
   } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
     Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited);
   } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
     if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
         isPointerTy(Src) && isPointerTy(Dest))
       Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited);
+  } else if (auto *Ref = dyn_cast<AtomicCmpXchgInst>(I)) {
+    Value *Op = Ref->getNewValOperand();
+    Ty = deduceElementTypeByValueDeep(Op->getType(), Op, Visited);
+  } else if (auto *Ref = dyn_cast<AtomicRMWInst>(I)) {
+    Value *Op = Ref->getValOperand();
+    Ty = deduceElementTypeByValueDeep(Op->getType(), Op, Visited);
+  } else if (auto *Ref = dyn_cast<PHINode>(I)) {
+    for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
+      Ty = deduceElementTypeByUsersDeep(Ref->getIncomingValue(i), Visited);
+      if (Ty)
+        break;
+    }
   }
 
   // remember the found relationship
@@ -293,6 +341,22 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper(
         return NewTy;
       }
     }
+  } else if (auto *VecTy = dyn_cast<VectorType>(OrigTy)) {
+    if (Value *Op = U->getNumOperands() > 0 ? U->getOperand(0) : nullptr) {
+      Type *OpTy = VecTy->getElementType();
+      Type *Ty = OpTy;
+      if (auto *PtrTy = dyn_cast<PointerType>(OpTy)) {
+        if (Type *NestedTy = deduceElementTypeHelper(Op, Visited))
+          Ty = TypedPointerType::get(NestedTy, PtrTy->getAddressSpace());
+      } else {
+        Ty = deduceNestedTypeHelper(dyn_cast<User>(Op), OpTy, Visited);
+      }
+      if (Ty != OpTy) {
+        Type *NewTy = VectorType::get(Ty, VecTy->getElementCount());
+        GR->addDeducedCompositeType(U, NewTy);
+        return NewTy;
+      }
+    }
   }
 
   return OrigTy;
@@ -578,7 +642,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
 
   // Handle calls to builtins (non-intrinsics):
   CallInst *CI = dyn_cast<CallInst>(I);
-  if (!CI || CI->isIndirectCall() || CI->getCalledFunction()->isIntrinsic())
+  if (!CI || CI->isIndirectCall() || CI->isInlineAsm() ||
+      !CI->getCalledFunction() || CI->getCalledFunction()->isIntrinsic())
     return;
 
   // collect information about formal parameter types
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 4f5c1dc4f90b0d..90a31551f45a23 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -201,6 +201,17 @@ void validateForwardCalls(const SPIRVSubtarget &STI,
     }
 }
 
+// Validate an access chain: operand 2 of I is checked against BaseElemType.
+void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
+                         SPIRVGlobalRegistry &GR, MachineInstr &I) {
+  SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg());
+  if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) {
+    SPIRVType *BaseElemType = // presumably the pointee type -- TODO confirm
+        GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg());
+    validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType);
+  } // if operand 0 has no known OpTypePointer type, nothing is validated
+}
+
 // TODO: the logic of inserting additional bitcast's is to be moved
 // to pre-IRTranslation passes eventually
 void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
@@ -213,16 +224,47 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
          MBBI != MBBE;) {
       MachineInstr &MI = *MBBI++;
       switch (MI.getOpcode()) {
+      case SPIRV::OpAtomicLoad:
+      case SPIRV::OpAtomicExchange:
+      case SPIRV::OpAtomicCompareExchange:
+      case SPIRV::OpAtomicCompareExchangeWeak:
+      case SPIRV::OpAtomicIIncrement:
+      case SPIRV::OpAtomicIDecrement:
+      case SPIRV::OpAtomicIAdd:
+      case SPIRV::OpAtomicISub:
+      case SPIRV::OpAtomicSMin:
+      case SPIRV::OpAtomicUMin:
+      case SPIRV::OpAtomicSMax:
+      case SPIRV::OpAtomicUMax:
+      case SPIRV::OpAtomicAnd:
+      case SPIRV::OpAtomicOr:
+      case SPIRV::OpAtomicXor:
+        // for the above listed instructions
+        // OpAtomicXXX <ResType>, ptr %Op, ...
+        // implies that %Op is a pointer to <ResType>
       case SPIRV::OpLoad:
         // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
         validatePtrTypes(STI, MRI, GR, MI, 2,
                          GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
         break;
+      case SPIRV::OpAtomicStore:
+        // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj>
+        // implies that %Op points to the <Obj>'s type
+        validatePtrTypes(STI, MRI, GR, MI, 0,
+                         GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg()));
+        break;
       case SPIRV::OpStore:
         // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
         validatePtrTypes(STI, MRI, GR, MI, 0,
                          GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()));
         break;
+      case SPIRV::OpPtrCastToGeneric:
+        validateAccessChain(STI, MRI, GR, MI);
+        break;
+      case SPIRV::OpInBoundsPtrAccessChain:
+        if (MI.getNumOperands() == 4)
+          validateAccessChain(STI, MRI, GR, MI);
+        break;
 
       case SPIRV::OpFunctionCall:
         // ensure there is no mismatch between actual and expected arg types:
diff --git a/llvm/test/CodeGen/SPIRV/instructions/atomic.ll b/llvm/test/CodeGen/SPIRV/instructions/atomic.ll
index 9715504fcc5d38..ce59bb2064027a 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/atomic.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/atomic.ll
@@ -1,4 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: OpName [[ADD:%.*]] "test_add"
 ; CHECK-DAG: OpName [[SUB:%.*]] "test_sub"
@@ -20,7 +21,8 @@
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicIAdd [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicIAdd [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_add(i32* %ptr, i32 %val) {
@@ -32,7 +34,8 @@ define i32 @test_add(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicISub [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicISub [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_sub(i32* %ptr, i32 %val) {
@@ -44,7 +47,8 @@ define i32 @test_sub(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicSMin [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicSMin [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_min(i32* %ptr, i32 %val) {
@@ -56,7 +60,8 @@ define i32 @test_min(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicSMax [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicSMax [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_max(i32* %ptr, i32 %val) {
@@ -68,7 +73,8 @@ define i32 @test_max(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicUMin [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicUMin [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_umin(i32* %ptr, i32 %val) {
@@ -80,7 +86,8 @@ define i32 @test_umin(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicUMax [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicUMax [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_umax(i32* %ptr, i32 %val) {
@@ -92,7 +99,8 @@ define i32 @test_umax(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicAnd [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicAnd [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_and(i32* %ptr, i32 %val) {
@@ -104,7 +112,8 @@ define i32 @test_and(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicOr [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicOr [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_or(i32* %ptr, i32 %val) {
@@ -116,7 +125,8 @@ define i32 @test_or(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicXor [[I32Ty]] [[A]] [[SCOPE]] [[RELAXED]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicXor [[I32Ty]] [[BC_A]] [[SCOPE]] [[RELAXED]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_xor(i32* %ptr, i32 %val) {
diff --git a/llvm/test/CodeGen/SPIRV/instructions/atomic_acqrel.ll b/llvm/test/CodeGen/SPIRV/instructions/atomic_acqrel.ll
index 63c0ae75f5ecdd..950dfe417637fe 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/atomic_acqrel.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/atomic_acqrel.ll
@@ -1,4 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: OpName [[ADD:%.*]] "test_add"
 ; CHECK-DAG: OpName [[SUB:%.*]] "test_sub"
@@ -20,7 +21,8 @@
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicIAdd [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicIAdd [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_add(i32* %ptr, i32 %val) {
@@ -32,7 +34,8 @@ define i32 @test_add(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicISub [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicISub [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_sub(i32* %ptr, i32 %val) {
@@ -44,7 +47,8 @@ define i32 @test_sub(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicSMin [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicSMin [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_min(i32* %ptr, i32 %val) {
@@ -56,7 +60,8 @@ define i32 @test_min(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicSMax [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicSMax [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_max(i32* %ptr, i32 %val) {
@@ -68,7 +73,8 @@ define i32 @test_max(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicUMin [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicUMin [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_umin(i32* %ptr, i32 %val) {
@@ -80,7 +86,8 @@ define i32 @test_umin(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicUMax [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicUMax [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_umax(i32* %ptr, i32 %val) {
@@ -92,7 +99,8 @@ define i32 @test_umax(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicAnd [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicAnd [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_and(i32* %ptr, i32 %val) {
@@ -104,7 +112,8 @@ define i32 @test_and(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicOr [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicOr [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_or(i32* %ptr, i32 %val) {
@@ -116,7 +125,8 @@ define i32 @test_or(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicXor [[I32Ty]] [[A]] [[SCOPE]] [[ACQREL]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicXor [[I32Ty]] [[BC_A]] [[SCOPE]] [[ACQREL]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_xor(i32* %ptr, i32 %val) {
diff --git a/llvm/test/CodeGen/SPIRV/instructions/atomic_seq.ll b/llvm/test/CodeGen/SPIRV/instructions/atomic_seq.ll
index f6a8fe1e6db18e..f142e012dcb744 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/atomic_seq.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/atomic_seq.ll
@@ -1,4 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: OpName [[ADD:%.*]] "test_add"
 ; CHECK-DAG: OpName [[SUB:%.*]] "test_sub"
@@ -20,7 +21,8 @@
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicIAdd [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicIAdd [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_add(i32* %ptr, i32 %val) {
@@ -32,7 +34,8 @@ define i32 @test_add(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicISub [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicISub [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_sub(i32* %ptr, i32 %val) {
@@ -44,7 +47,8 @@ define i32 @test_sub(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicSMin [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicSMin [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_min(i32* %ptr, i32 %val) {
@@ -56,7 +60,8 @@ define i32 @test_min(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicSMax [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicSMax [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_max(i32* %ptr, i32 %val) {
@@ -68,7 +73,8 @@ define i32 @test_max(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicUMin [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicUMin [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_umin(i32* %ptr, i32 %val) {
@@ -80,7 +86,8 @@ define i32 @test_umin(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicUMax [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicUMax [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_umax(i32* %ptr, i32 %val) {
@@ -92,7 +99,8 @@ define i32 @test_umax(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicAnd [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicAnd [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_and(i32* %ptr, i32 %val) {
@@ -104,7 +112,8 @@ define i32 @test_and(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicOr [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicOr [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_or(i32* %ptr, i32 %val) {
@@ -116,7 +125,8 @@ define i32 @test_or(i32* %ptr, i32 %val) {
 ; CHECK-NEXT: [[A:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: [[B:%.*]] = OpFunctionParameter
 ; CHECK-NEXT: OpLabel
-; CHECK-NEXT: [[R:%.*]] = OpAtomicXor [[I32Ty]] [[A]] [[SCOPE]] [[SEQ]] [[B]]
+; CHECK-NEXT: [[BC_A:%.*]] = OpBitcast %[[#]] [[A]]
+; CHECK-NEXT: [[R:%.*]] = OpAtomicXor [[I32Ty]] [[BC_A]] [[SCOPE]] [[SEQ]] [[B]]
 ; CHECK-NEXT: OpReturnValue [[R]]
 ; CHECK-NEXT: OpFunctionEnd
 define i32 @test_xor(i32* %ptr, i32 %val) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-accesschain.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-accesschain.ll
new file mode 100644
index 00000000000000..7fae6ca2c48cf1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-accesschain.ll
@@ -0,0 +1,37 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#TYCHAR:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#TYCHARPTR:]] = OpTypePointer Function %[[#TYCHAR]]
+; CHECK-DAG: %[[#TYINT32:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYSTRUCTINT32:]] = OpTypeStruct %[[#TYINT32]]
+; CHECK-DAG: %[[#TYARRAY:]] = OpTypeArray %[[#TYSTRUCTINT32]] %[[#]]
+; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYARRAY]]
+; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]]
+; CHECK-DAG: %[[#TYINT64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#TYINT64PTR:]] = OpTypePointer Function %[[#TYINT64]]
+; CHECK: OpFunction
+; CHECK: %[[#PTRTOSTRUCT:]] = OpFunctionParameter %[[#TYSTRUCTPTR]]
+; CHECK: %[[#PTRTOCHAR:]] = OpBitcast %[[#TYCHARPTR]] %[[#PTRTOSTRUCT]]
+; CHECK-NEXT: OpInBoundsPtrAccessChain %[[#TYCHARPTR]] %[[#PTRTOCHAR]]
+; CHECK: OpFunction
+; CHECK: %[[#PTRTOSTRUCT2:]] = OpFunctionParameter %[[#TYSTRUCTPTR]]
+; CHECK: %[[#ELEM:]] = OpInBoundsPtrAccessChain %[[#TYSTRUCTPTR]] %[[#PTRTOSTRUCT2]]
+; CHECK-NEXT: %[[#TOLOAD:]] = OpBitcast %[[#TYINT64PTR]] %[[#ELEM]]
+; CHECK-NEXT: OpLoad %[[#TYINT64]] %[[#TOLOAD]]
+
+%struct.S = type { i32 }
+%struct.__wrapper_class = type { [7 x %struct.S] }
+
+define spir_kernel void @foo1(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
+entry:
+  %elem = getelementptr inbounds i8, ptr %_arg_Arr, i64 0
+  ret void
+}
+
+define spir_kernel void @foo2(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
+entry:
+  %elem = getelementptr inbounds %struct.__wrapper_class, ptr %_arg_Arr, i64 0
+  %data = load i64, ptr %elem
+  ret void
+}

>From 0cdc12208f8974b3cf58284719d9f8f6f53cf8d1 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 29 Mar 2024 08:26:56 -0700
Subject: [PATCH 2/2] fix double function header emission; improve type
 inference

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  7 +++++
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  4 +++
 llvm/test/CodeGen/SPIRV/ExecutionMode.ll      |  1 +
 .../pointers/type-deduce-by-call-complex.ll   | 29 +++++++++++++++++++
 4 files changed, 41 insertions(+)
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-complex.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index ad4e72a3128b1e..1674cef7cb8270 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -251,6 +251,13 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
             cast<ConstantInt>(II->getOperand(2))->getZExtValue(), ST));
   }
 
+  // Replace PointerType with TypedPointerType to be able to map SPIR-V types to
+  // LLVM types in a consistent manner
+  if (isUntypedPointerTy(OriginalArgType)) {
+    OriginalArgType =
+        TypedPointerType::get(Type::getInt8Ty(F.getContext()),
+                              getPointerAddressSpace(OriginalArgType));
+  }
   return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 3679e060cbe6f2..b341fcb41d0312 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -994,6 +994,10 @@ Type *SPIRVEmitIntrinsics::deduceFunParamElementType(
     // maybe we already know operand's element type
     if (Type *KnownTy = GR->findDeducedElementType(OpArg))
       return KnownTy;
+    // try to deduce from the operand itself
+    Visited.clear();
+    if (Type *Ty = deduceElementTypeHelper(OpArg, Visited))
+      return Ty;
     // search in actual parameter's users
     for (User *OpU : OpArg->users()) {
       Instruction *Inst = dyn_cast<Instruction>(OpU);
diff --git a/llvm/test/CodeGen/SPIRV/ExecutionMode.ll b/llvm/test/CodeGen/SPIRV/ExecutionMode.ll
index 3e321e1c2bd280..180b7246952db5 100644
--- a/llvm/test/CodeGen/SPIRV/ExecutionMode.ll
+++ b/llvm/test/CodeGen/SPIRV/ExecutionMode.ll
@@ -1,4 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: %[[#VOID:]] = OpTypeVoid
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-complex.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-complex.ll
new file mode 100644
index 00000000000000..ea7a22c31d0e85
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-complex.ll
@@ -0,0 +1,29 @@
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[Struct:.*]] = OpTypeStruct %[[Long]]
+; CHECK-SPIRV-DAG: %[[StructPtr:.*]] = OpTypePointer Generic %[[Struct]]
+; CHECK-SPIRV-DAG: %[[Function:.*]] = OpTypeFunction %[[Void]] %[[StructPtr]]
+; CHECK-SPIRV-DAG: %[[Const:.*]] = OpConstantNull %[[Struct]]
+; CHECK-SPIRV-DAG: %[[CrossStructPtr:.*]] = OpTypePointer CrossWorkgroup %[[Struct]]
+; CHECK-SPIRV-DAG: %[[Var:.*]] = OpVariable %[[CrossStructPtr]] CrossWorkgroup %[[Const]]
+; CHECK-SPIRV: %[[Foo:.*]] = OpFunction %[[Void]] None %[[Function]]
+; CHECK-SPIRV-NEXT: OpFunctionParameter %[[StructPtr]]
+; CHECK-SPIRV: %[[Casted:.*]] = OpPtrCastToGeneric %[[StructPtr]] %[[Var]]
+; CHECK-SPIRV-NEXT: OpFunctionCall %[[Void]] %[[Foo]] %[[Casted]]
+
+%struct.global_ctor_dtor = type { i32 }
+@g1 = addrspace(1) global %struct.global_ctor_dtor zeroinitializer
+
+define linkonce_odr spir_func void @foo(ptr addrspace(4) %this) {
+entry:
+  ret void
+}
+
+define internal spir_func void @bar() {
+entry:
+  call spir_func void @foo(ptr addrspace(4) addrspacecast (ptr addrspace(1) @g1 to ptr addrspace(4)))
+  ret void
+}



More information about the llvm-commits mailing list