[llvm] [SPIR-V] Insert a bitcast before load/store instruction to keep SPIR-V code valid (PR #84069)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 7 03:53:24 PST 2024


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/84069

>From afcbb97922f521fa7be986eeb157bcef225cd5a5 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 5 Mar 2024 12:53:05 -0800
Subject: [PATCH 1/5] keep code valid by inserting additional bitcasts

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp |  7 ++
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   | 78 +++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVISelLowering.h     | 12 ++-
 llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll   | 24 ++++++
 llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll  | 35 +++++++++
 .../SPIRV/constant/global-constants.ll        |  3 +
 llvm/test/CodeGen/SPIRV/spirv-load-store.ll   | 10 ++-
 7 files changed, 164 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e88298f52fbe18..fea9366efc3a58 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -517,6 +517,13 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
     MRI->setType(Reg, RegLLTy);
     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
+  } else {
+    // Our knowledge about the type may be updated.
+    // If that's the case, we need to update a type
+    // associated with the register.
+    SPIRVType *DefType = getSPIRVTypeForVReg(ResVReg);
+    if (!DefType || DefType != BaseType)
+      assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
   }
 
   // If it's a global variable with name, output OpName for it.
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 33c6aa242969de..27539422302ab7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -12,6 +12,13 @@
 
 #include "SPIRVISelLowering.h"
 #include "SPIRV.h"
+#include "SPIRVInstrInfo.h"
+#include "SPIRVRegisterBankInfo.h"
+#include "SPIRVRegisterInfo.h"
+#include "SPIRVSubtarget.h"
+#include "SPIRVTargetMachine.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
 
 #define DEBUG_TYPE "spirv-lower"
@@ -74,3 +81,74 @@ bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
   }
   return false;
 }
+
+// Insert a bitcast before the instruction to keep SPIR-V code valid
+// when there is a type mismatch between results and operand types.
+static void validatePtrTypes(const SPIRVSubtarget &STI,
+                             MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR,
+                             MachineInstr &I, SPIRVType *ResType,
+                             unsigned OpIdx) {
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(
+      TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
+          ? TypeInst->getOperand(1).getReg()
+          : OpReg);
+  if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+    return;
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  if (!ElemType || ElemType == ResType)
+    return;
+  // The pointee type differs from ResType: bitcast the pointer operand to a
+  // pointer to ResType in the same storage class and make I use the result.
+  SPIRV::StorageClass::StorageClass SC =
+      static_cast<SPIRV::StorageClass::StorageClass>(
+          OpType->getOperand(1).getImm());
+  // Insert the bitcast right before I: this is after the definition of OpReg
+  // (which may be the immediately preceding instruction) and before I, the
+  // only user of the new register. MachineIRBuilder(I) also takes I's DebugLoc
+  // so the bitcast is attributed to the same source location as I.
+  MachineIRBuilder MIB(I);
+  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC);
+  if (!GR.isBitcastCompatible(NewPtrType, OpType))
+    report_fatal_error(
+        "insert validation bitcast: incompatible result and operand types");
+  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
+                 .addDef(NewReg)
+                 .addUse(GR.getSPIRVTypeID(NewPtrType))
+                 .addUse(OpReg)
+                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
+                                   *STI.getRegBankInfo());
+  if (!Res)
+    report_fatal_error("insert validation bitcast: cannot constrain all uses");
+  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
+  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
+  I.getOperand(OpIdx).setReg(NewReg);
+}
+
+void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
+  MachineRegisterInfo *MRI = &MF.getRegInfo();
+  SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
+  GR.setCurrentFunc(MF);
+  // Visit every OpLoad/OpStore and make its pointer operand agree with the
+  // value type, inserting a bitcast where validatePtrTypes() finds a mismatch.
+  for (MachineBasicBlock &Block : MF) {
+    // The cursor is advanced before MI is handled, so instructions inserted
+    // in front of MI by validatePtrTypes() are never visited by this walk.
+    for (auto Cursor = Block.begin(), Stop = Block.end(); Cursor != Stop;) {
+      MachineInstr &MI = *Cursor++;
+      const unsigned Opcode = MI.getOpcode();
+      if (Opcode == SPIRV::OpLoad) {
+        // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType>
+        Register Loaded = MI.getOperand(0).getReg();
+        validatePtrTypes(STI, MRI, GR, MI, GR.getSPIRVTypeForVReg(Loaded), 2);
+      } else if (Opcode == SPIRV::OpStore) {
+        // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type
+        Register Stored = MI.getOperand(1).getReg();
+        validatePtrTypes(STI, MRI, GR, MI, GR.getSPIRVTypeForVReg(Stored), 0);
+      }
+    }
+  }
+  TargetLowering::finalizeLowering(MF);
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
index d34f802e9d889f..b01571bfc1eeb5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.h
@@ -14,16 +14,19 @@
 #ifndef LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H
 #define LLVM_LIB_TARGET_SPIRV_SPIRVISELLOWERING_H
 
+#include "SPIRVGlobalRegistry.h"
 #include "llvm/CodeGen/TargetLowering.h"
 
 namespace llvm {
 class SPIRVSubtarget;
 
 class SPIRVTargetLowering : public TargetLowering {
+  const SPIRVSubtarget &STI;
+
 public:
   explicit SPIRVTargetLowering(const TargetMachine &TM,
-                               const SPIRVSubtarget &STI)
-      : TargetLowering(TM) {}
+                               const SPIRVSubtarget &ST)
+      : TargetLowering(TM), STI(ST) {}
 
   // Stop IRTranslator breaking up FMA instrs to preserve types information.
   bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
@@ -47,6 +50,11 @@ class SPIRVTargetLowering : public TargetLowering {
   bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
                           MachineFunction &MF,
                           unsigned Intrinsic) const override;
+
+  // Call the default implementation and finalize target lowering by inserting
+  // extra instructions required to preserve validity of SPIR-V code imposed by
+  // the standard.
+  void finalizeLowering(MachineFunction &MF) const override;
 };
 } // namespace llvm
 
diff --git a/llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll b/llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll
new file mode 100644
index 00000000000000..a2b3cb9349aaf6
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll
@@ -0,0 +1,24 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYSTRUCTLONG:]] = OpTypeStruct %[[#TYLONG]]
+; CHECK-DAG: %[[#TYARRAY:]] = OpTypeArray %[[#TYSTRUCTLONG]] %[[#]]
+; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYARRAY]]
+; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]]
+; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]]
+; CHECK: %[[#PTRTOSTRUCT:]] = OpFunctionParameter %[[#TYSTRUCTPTR]]
+; CHECK: %[[#PTRTOLONG:]] = OpBitcast %[[#TYLONGPTR]] %[[#PTRTOSTRUCT]]
+; CHECK: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+%struct.S = type { i32 }
+%struct.__wrapper_class = type { [7 x %struct.S] }
+
+define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
+entry:
+  %val = load i32, ptr %_arg_Arr ; i32 load through the byval %struct.__wrapper_class pointer (needs OpBitcast)
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll b/llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll
new file mode 100644
index 00000000000000..4d216df8514a46
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll
@@ -0,0 +1,35 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]]
+; CHECK-DAG: %[[#TYSTRUCT:]] = OpTypeStruct %[[#TYLONG]]
+; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 3
+; CHECK-DAG: %[[#TYSTRUCTPTR:]] = OpTypePointer Function %[[#TYSTRUCT]]
+; CHECK: OpFunction
+; CHECK: %[[#ARGPTR1:]] = OpFunctionParameter %[[#TYLONGPTR]]
+; CHECK: OpStore %[[#ARGPTR1]] %[[#CONST]]
+; CHECK: OpFunction
+; CHECK: %[[#OBJ:]] = OpFunctionParameter %[[#TYSTRUCT]]
+; CHECK: %[[#ARGPTR2:]] = OpFunctionParameter %[[#TYLONGPTR]]
+; CHECK: %[[#PTRTOSTRUCT:]] = OpBitcast %[[#TYSTRUCTPTR]] %[[#ARGPTR2]]
+; CHECK: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]]
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
+target triple = "spir64-unknown-unknown"
+
+%struct.S = type { i32 }
+%struct.__wrapper_class = type { [7 x %struct.S] }
+
+;define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
+define spir_kernel void @foo(%struct.S %arg, ptr %ptr) {
+entry:
+  store %struct.S %arg, ptr %ptr
+  ret void
+}
+
+define spir_kernel void @bar(ptr %ptr) {
+entry:
+  store i32 3, ptr %ptr ; same pointee type: stored through the parameter directly
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/constant/global-constants.ll b/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
index 916c70628d0169..1e400accaec0c1 100644
--- a/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
+++ b/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
@@ -1,5 +1,8 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir64"
+
 @global   = addrspace(1) constant i32 1 ; OpenCL global memory
 @constant = addrspace(2) constant i32 2 ; OpenCL constant memory
 @local    = addrspace(3) constant i32 3 ; OpenCL local memory
diff --git a/llvm/test/CodeGen/SPIRV/spirv-load-store.ll b/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
index a82bf0ab2e01f6..9788f0a651c4d2 100644
--- a/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
+++ b/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
@@ -1,9 +1,13 @@
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
 ;; Translate SPIR-V friendly OpLoad and OpStore calls
 
-; CHECK: %[[#CONST:]] = OpConstant %[[#]] 42
-; CHECK: OpStore %[[#PTR:]] %[[#CONST]] Volatile|Aligned 4
-; CHECK: %[[#]] = OpLoad %[[#]] %[[#PTR]]
+; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#TYFLOAT:]] = OpTypeFloat 64
+; CHECK-DAG: %[[#TYFLOATPTR:]] = OpTypePointer CrossWorkgroup %[[#TYFLOAT]]
+; CHECK-DAG: %[[#CONST:]] = OpConstant %[[#TYLONG]] 42
+; CHECK: OpStore %[[#PTRTOLONG:]] %[[#CONST]] Volatile|Aligned 4
+; CHECK: %[[#PTRTOFLOAT:]] = OpBitcast %[[#TYFLOATPTR]] %[[#PTRTOLONG]]
+; CHECK: OpLoad %[[#TYFLOAT]] %[[#PTRTOFLOAT]]
 
 define weak_odr dso_local spir_kernel void @foo(i32 addrspace(1)* %var) {
 entry:

>From 671ad279fb8c0b6a218fded9eec1f2ed525dd5e0 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Tue, 5 Mar 2024 13:25:18 -0800
Subject: [PATCH 2/5] move tests

---
 llvm/test/CodeGen/SPIRV/{ => pointers}/bitcast-fix-load.ll  | 0
 llvm/test/CodeGen/SPIRV/{ => pointers}/bitcast-fix-store.ll | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename llvm/test/CodeGen/SPIRV/{ => pointers}/bitcast-fix-load.ll (100%)
 rename llvm/test/CodeGen/SPIRV/{ => pointers}/bitcast-fix-store.ll (100%)

diff --git a/llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
similarity index 100%
rename from llvm/test/CodeGen/SPIRV/bitcast-fix-load.ll
rename to llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
diff --git a/llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
similarity index 100%
rename from llvm/test/CodeGen/SPIRV/bitcast-fix-store.ll
rename to llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll

>From d711ca8d45bca0fd50f9667171432dc649dee041 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Wed, 6 Mar 2024 00:11:43 -0800
Subject: [PATCH 3/5] fix tests

---
 llvm/test/CodeGen/SPIRV/constant/global-constants.ll  | 4 +---
 llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll  | 3 ---
 llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll | 3 ---
 llvm/test/CodeGen/SPIRV/spirv-load-store.ll           | 1 +
 4 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/llvm/test/CodeGen/SPIRV/constant/global-constants.ll b/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
index 1e400accaec0c1..74e28cbe7acb17 100644
--- a/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
+++ b/llvm/test/CodeGen/SPIRV/constant/global-constants.ll
@@ -1,7 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
-
-target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
-target triple = "spir64"
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 @global   = addrspace(1) constant i32 1 ; OpenCL global memory
 @constant = addrspace(2) constant i32 2 ; OpenCL constant memory
diff --git a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
index a2b3cb9349aaf6..a30d0792e39988 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll
@@ -11,9 +11,6 @@
 ; CHECK: %[[#PTRTOLONG:]] = OpBitcast %[[#TYLONGPTR]] %[[#PTRTOSTRUCT]]
 ; CHECK: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]]
 
-target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
-target triple = "spir64-unknown-unknown"
-
 %struct.S = type { i32 }
 %struct.__wrapper_class = type { [7 x %struct.S] }
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
index 4d216df8514a46..302092bc1ed285 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
@@ -15,9 +15,6 @@
 ; CHECK: %[[#PTRTOSTRUCT:]] = OpBitcast %[[#TYSTRUCTPTR]] %[[#ARGPTR2]]
 ; CHECK: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]]
 
-target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
-target triple = "spir64-unknown-unknown"
-
 %struct.S = type { i32 }
 %struct.__wrapper_class = type { [7 x %struct.S] }
 
diff --git a/llvm/test/CodeGen/SPIRV/spirv-load-store.ll b/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
index 9788f0a651c4d2..9188617312466d 100644
--- a/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
+++ b/llvm/test/CodeGen/SPIRV/spirv-load-store.ll
@@ -1,4 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 ;; Translate SPIR-V friendly OpLoad and OpStore calls
 
 ; CHECK-DAG: %[[#TYLONG:]] = OpTypeInt 32 0

>From b7e20efb3dcecd16204799ddc60ba64e64b31a0b Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 7 Mar 2024 03:36:18 -0800
Subject: [PATCH 4/5] introduce TypedPointerType to fix folded opaque pointers

---
 llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp |  4 +-
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 70 +++++++++++--------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   |  1 +
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   |  2 +
 .../SPIRV/pointers/bitcast-fix-store.ll       |  1 -
 5 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index c83537bc7ae8a5..c0e64d3c1a954e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/InstVisitor.h"
 #include "llvm/IR/IntrinsicsSPIRV.h"
+#include "llvm/IR/TypedPointerType.h"
 
 #include <queue>
 
@@ -434,7 +435,8 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I) {
 
   for (unsigned OpIdx = 0; OpIdx < CI->arg_size(); OpIdx++) {
     Value *ArgOperand = CI->getArgOperand(OpIdx);
-    if (!isa<PointerType>(ArgOperand->getType()))
+    if (!isa<PointerType>(ArgOperand->getType()) &&
+        !isa<TypedPointerType>(ArgOperand->getType()))
       continue;
 
     // Constants (nulls/undefs) are handled in insertAssignPtrTypeIntrs()
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index fea9366efc3a58..1a489bd59d1488 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -20,6 +20,7 @@
 #include "SPIRVSubtarget.h"
 #include "SPIRVTargetMachine.h"
 #include "SPIRVUtils.h"
+#include "llvm/IR/TypedPointerType.h"
 
 using namespace llvm;
 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
@@ -420,9 +421,10 @@ Register
 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                              SPIRVType *SpvType) {
   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
-  const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy);
+  const TypedPointerType *LLVMPtrTy = cast<TypedPointerType>(LLVMTy);
   // Find a constant in DT or build a new one.
-  Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy));
+  Constant *CP = ConstantPointerNull::get(PointerType::get(
+      LLVMPtrTy->getElementType(), LLVMPtrTy->getAddressSpace()));
   Register Res = DT.find(CP, CurMF);
   if (!Res.isValid()) {
     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
@@ -712,33 +714,37 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
     }
     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
   }
-  if (auto PType = dyn_cast<PointerType>(Ty)) {
-    SPIRVType *SpvElementType;
-    // At the moment, all opaque pointers correspond to i8 element type.
-    // TODO: change the implementation once opaque pointers are supported
-    // in the SPIR-V specification.
-    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
-    // Get access to information about available extensions
-    const SPIRVSubtarget *ST =
-        static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
-    auto SC = addressSpaceToStorageClass(PType->getAddressSpace(), *ST);
-    // Null pointer means we have a loop in type definitions, make and
-    // return corresponding OpTypeForwardPointer.
-    if (SpvElementType == nullptr) {
-      if (!ForwardPointerTypes.contains(Ty))
-        ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
-      return ForwardPointerTypes[PType];
-    }
-    // If we have forward pointer associated with this type, use its register
-    // operand to create OpTypePointer.
-    if (ForwardPointerTypes.contains(PType)) {
-      Register Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);
-      return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
-    }
-
-    return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
+  unsigned AddrSpace = 0xFFFF;
+  if (auto PType = dyn_cast<TypedPointerType>(Ty))
+    AddrSpace = PType->getAddressSpace();
+  else if (auto PType = dyn_cast<PointerType>(Ty))
+    AddrSpace = PType->getAddressSpace();
+  else
+    report_fatal_error("Unable to convert LLVM type to SPIRVType", true);
+  SPIRVType *SpvElementType;
+  // At the moment, all opaque pointers correspond to i8 element type.
+  // TODO: change the implementation once opaque pointers are supported
+  // in the SPIR-V specification.
+  SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
+  // Get access to information about available extensions
+  const SPIRVSubtarget *ST =
+      static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
+  auto SC = addressSpaceToStorageClass(AddrSpace, *ST);
+  // Null pointer means we have a loop in type definitions, make and
+  // return corresponding OpTypeForwardPointer.
+  if (SpvElementType == nullptr) {
+    if (!ForwardPointerTypes.contains(Ty))
+      ForwardPointerTypes[Ty] = getOpTypeForwardPointer(SC, MIRBuilder);
+    return ForwardPointerTypes[Ty];
+  }
+  // If we have forward pointer associated with this type, use its register
+  // operand to create OpTypePointer.
+  if (ForwardPointerTypes.contains(Ty)) {
+    Register Reg = getSPIRVTypeID(ForwardPointerTypes[Ty]);
+    return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
   }
-  llvm_unreachable("Unable to convert LLVM type to SPIRVType");
+
+  return getOrCreateSPIRVPointerType(SpvElementType, MIRBuilder, SC);
 }
 
 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
@@ -1147,10 +1153,12 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
   unsigned AddressSpace = storageClassToAddressSpace(SC);
   Type *LLVMTy =
-      PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
+      TypedPointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
+  // check if this type is already available
   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
   if (Reg.isValid())
     return getSPIRVTypeForVReg(Reg);
+  // create a new type
   auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
                      MIRBuilder.getDebugLoc(),
                      MIRBuilder.getTII().get(SPIRV::OpTypePointer))
@@ -1167,10 +1175,12 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
   unsigned AddressSpace = storageClassToAddressSpace(SC);
   Type *LLVMTy =
-      PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
+      TypedPointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
+  // check if this type is already available
   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
   if (Reg.isValid())
     return getSPIRVTypeForVReg(Reg);
+  // create a new type
   MachineBasicBlock &BB = *I.getParent();
   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
                  .addDef(createTypeVReg(CurMF->getRegInfo()))
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index f5a83072c19d76..9c0061d13fd0cf 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -34,6 +34,7 @@ class SPIRVGlobalRegistry {
   DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
       VRegToTypeMap;
 
+  // Map LLVM Type* to <MF, Reg>
   SPIRVGeneralDuplicatesTracker DT;
 
   DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 27539422302ab7..61748070fc0fb2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -127,6 +127,8 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
   I.getOperand(OpIdx).setReg(NewReg);
 }
 
+// TODO: the logic of inserting additional bitcasts is to be moved
+// to pre-IRTranslation passes eventually.
 void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
   MachineRegisterInfo *MRI = &MF.getRegInfo();
   SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
diff --git a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
index 302092bc1ed285..4701f02ea33af3 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll
@@ -18,7 +18,6 @@
 %struct.S = type { i32 }
 %struct.__wrapper_class = type { [7 x %struct.S] }
 
-;define spir_kernel void @foo(ptr noundef byval(%struct.__wrapper_class) align 4 %_arg_Arr) {
 define spir_kernel void @foo(%struct.S %arg, ptr %ptr) {
 entry:
   store %struct.S %arg, ptr %ptr

>From 9d46a5924a8b00c560ebb1a238d680e9e9dbfb62 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 7 Mar 2024 03:53:11 -0800
Subject: [PATCH 5/5] clang-format and remove duplicated code

---
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 24 ++++---------------
 1 file changed, 5 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 1a489bd59d1488..8556581996fede 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1152,8 +1152,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
     SPIRV::StorageClass::StorageClass SC) {
   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
   unsigned AddressSpace = storageClassToAddressSpace(SC);
-  Type *LLVMTy =
-      TypedPointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
+  Type *LLVMTy = TypedPointerType::get(const_cast<Type *>(PointerElementType),
+                                       AddressSpace);
   // check if this type is already available
   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
   if (Reg.isValid())
@@ -1170,24 +1170,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
-    SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
+    SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &,
     SPIRV::StorageClass::StorageClass SC) {
-  const Type *PointerElementType = getTypeForSPIRVType(BaseType);
-  unsigned AddressSpace = storageClassToAddressSpace(SC);
-  Type *LLVMTy =
-      TypedPointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
-  // check if this type is already available
-  Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
-  if (Reg.isValid())
-    return getSPIRVTypeForVReg(Reg);
-  // create a new type
-  MachineBasicBlock &BB = *I.getParent();
-  auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
-                 .addDef(createTypeVReg(CurMF->getRegInfo()))
-                 .addImm(static_cast<uint32_t>(SC))
-                 .addUse(getSPIRVTypeID(BaseType));
-  DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
-  return finishCreatingSPIRVType(LLVMTy, MIB);
+  MachineIRBuilder MIRBuilder(I); // Insert before I with I's DebugLoc.
+  return getOrCreateSPIRVPointerType(BaseType, MIRBuilder, SC);
 }
 
 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,



More information about the llvm-commits mailing list