[llvm] [SPIR-V] Fix generation of gMIR vs. SPIR-V code from utility methods (PR #128159)

Vyacheslav Levytskyy via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 21 02:39:37 PST 2025


https://github.com/VyacheslavLevytskyy updated https://github.com/llvm/llvm-project/pull/128159

>From d579bb2192a9412693f19492d6d692d3083232bb Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Thu, 20 Feb 2025 13:42:20 -0800
Subject: [PATCH 1/2] fix generation of gMIR vs. SPIR-V code

---
 llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp       | 16 ++++++----
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp   |  3 +-
 .../Target/SPIRV/SPIRVEmitNonSemanticDI.cpp   |  6 ++--
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp | 32 +++++++++++--------
 llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h   | 22 ++++++-------
 llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp   |  5 +--
 6 files changed, 46 insertions(+), 38 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 7b897f7e34c6f..c65166902550c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -476,8 +476,9 @@ static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
   if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
     unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
     uint64_t AllOnes = APInt::getAllOnes(Bits).getZExtValue();
-    TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
-    FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
+    TrueConst =
+        GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType, true);
+    FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType, true);
   } else {
     TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
     FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
@@ -1457,7 +1458,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
       ToTruncate = DefaultReg;
     }
     auto NewRegister =
-        GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+        GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
     MIRBuilder.buildCopy(DefaultReg, NewRegister);
   } else { // If it could be in range, we need to load from the given builtin.
     auto Vec3Ty =
@@ -1492,13 +1493,14 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
       GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
 
       // Use G_ICMP to check if idxVReg < 3.
-      MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
-                           GR->buildConstantInt(3, MIRBuilder, IndexType));
+      MIRBuilder.buildICmp(
+          CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
+          GR->buildConstantInt(3, MIRBuilder, IndexType, true));
 
       // Get constant for the default value (0 or 1 depending on which
       // function).
       Register DefaultRegister =
-          GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+          GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
 
       // Get a register for the selection result (possibly a new temporary one).
       Register SelectionResult = Call->ReturnRegister;
@@ -2277,7 +2279,7 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
       Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
                                            SpvFieldTy, *ST.getInstrInfo());
     } else {
-      Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
+      Const = GR->buildConstantInt(0, MIRBuilder, SpvTy, true);
     }
     if (!LocalWorkSize.isValid())
       LocalWorkSize = Const;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 78f6b188c45c1..e47dfddd55975 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -669,7 +669,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   // Make sure there's a valid return reg, even for functions returning void.
   if (!ResVReg.isValid())
     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
-  SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
+  SPIRVType *RetType = GR->assignTypeToVReg(
+      OrigRetTy, ResVReg, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
 
   // Emit the call instruction and its args.
   auto MIB = MIRBuilder.buildInstr(CallOp)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
index b98cef0a4f07f..ee98af5cffe4c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
@@ -193,7 +193,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
     };
 
     const SPIRVType *VoidTy =
-        GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder);
+        GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder,
+                                 SPIRV::AccessQualifier::ReadWrite, false);
 
     const auto EmitDIInstruction =
         [&](SPIRV::NonSemanticExtInst::NonSemanticExtInst Inst,
@@ -217,7 +218,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
         };
 
     const SPIRVType *I32Ty =
-        GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder);
+        GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder,
+                                 SPIRV::AccessQualifier::ReadWrite, false);
 
     const Register DwarfVersionReg =
         GR->buildConstantInt(DwarfVersion, MIRBuilder, I32Ty, false);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e2f1b211caa5c..a09474a21534e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -244,7 +244,8 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
         CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
     CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
     if (MIRBuilder)
-      assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
+      assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder,
+                       SPIRV::AccessQualifier::ReadWrite, true);
     else
       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
     DT.add(CI, CurMF, Res);
@@ -271,7 +272,8 @@ SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
         CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
     CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
     if (MIRBuilder)
-      assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
+      assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder,
+                       SPIRV::AccessQualifier::ReadWrite, true);
     else
       assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
     DT.add(CI, CurMF, Res);
@@ -878,12 +880,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
   });
 }
 
-SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
-                                                MachineIRBuilder &MIRBuilder,
-                                                bool EmitIR) {
+SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
+    const StructType *Ty, MachineIRBuilder &MIRBuilder,
+    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
   SmallVector<Register, 4> FieldTypes;
   for (const auto &Elem : Ty->elements()) {
-    SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder);
+    SPIRVType *ElemTy =
+        findSPIRVType(toTypedPointer(Elem), MIRBuilder, AccQual, EmitIR);
     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
            "Invalid struct element type");
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -1017,26 +1020,27 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
-    SPIRVType *El =
-        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
+    SPIRVType *El = findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
+                                  MIRBuilder, AccQual, EmitIR);
     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                            MIRBuilder);
   }
   if (Ty->isArrayTy()) {
-    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
+    SPIRVType *El =
+        findSPIRVType(Ty->getArrayElementType(), MIRBuilder, AccQual, EmitIR);
     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
   }
   if (auto SType = dyn_cast<StructType>(Ty)) {
     if (SType->isOpaque())
       return getOpTypeOpaque(SType, MIRBuilder);
-    return getOpTypeStruct(SType, MIRBuilder, EmitIR);
+    return getOpTypeStruct(SType, MIRBuilder, AccQual, EmitIR);
   }
   if (auto FType = dyn_cast<FunctionType>(Ty)) {
-    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
+    SPIRVType *RetTy =
+        findSPIRVType(FType->getReturnType(), MIRBuilder, AccQual, EmitIR);
     SmallVector<SPIRVType *, 4> ParamTypes;
-    for (const auto &t : FType->params()) {
-      ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
-    }
+    for (const auto &ParamTy : FType->params())
+      ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR));
     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 0c94ec4df97f5..5ad705f197d5a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -94,13 +94,11 @@ class SPIRVGlobalRegistry {
 
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
-                             SPIRV::AccessQualifier::AccessQualifier AQ =
-                                 SPIRV::AccessQualifier::ReadWrite,
-                             bool EmitIR = true);
+                             SPIRV::AccessQualifier::AccessQualifier AQ,
+                             bool EmitIR);
   SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
-                           SPIRV::AccessQualifier::AccessQualifier accessQual =
-                               SPIRV::AccessQualifier::ReadWrite,
-                           bool EmitIR = true);
+                           SPIRV::AccessQualifier::AccessQualifier accessQual,
+                           bool EmitIR);
   SPIRVType *
   restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                         SPIRV::AccessQualifier::AccessQualifier AccessQual,
@@ -321,9 +319,8 @@ class SPIRVGlobalRegistry {
   // and map it to the given VReg by creating an ASSIGN_TYPE instruction.
   SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
                               MachineIRBuilder &MIRBuilder,
-                              SPIRV::AccessQualifier::AccessQualifier AQ =
-                                  SPIRV::AccessQualifier::ReadWrite,
-                              bool EmitIR = true);
+                              SPIRV::AccessQualifier::AccessQualifier AQ,
+                              bool EmitIR);
   SPIRVType *assignIntTypeToVReg(unsigned BitWidth, Register VReg,
                                  MachineInstr &I, const SPIRVInstrInfo &TII);
   SPIRVType *assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
@@ -470,13 +467,14 @@ class SPIRVGlobalRegistry {
                              MachineIRBuilder &MIRBuilder);
 
   SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
-                            MachineIRBuilder &MIRBuilder, bool EmitIR = true);
+                            MachineIRBuilder &MIRBuilder, bool EmitIR);
 
   SPIRVType *getOpTypeOpaque(const StructType *Ty,
                              MachineIRBuilder &MIRBuilder);
 
   SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
-                             bool EmitIR = true);
+                             SPIRV::AccessQualifier::AccessQualifier AccQual,
+                             bool EmitIR);
 
   SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
                               SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
@@ -539,7 +537,7 @@ class SPIRVGlobalRegistry {
                                     SPIRVType *SpvType,
                                     const SPIRVInstrInfo &TII);
   Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
-                                    SPIRVType *SpvType, bool EmitIR = true);
+                                    SPIRVType *SpvType, bool EmitIR);
   Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                    SPIRVType *SpvType);
   Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index d5b81bf46c804..f622be893919f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -192,7 +192,7 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
   // Insert a bitcast before the instruction to keep SPIR-V code valid.
   LLVMContext &Context = MF->getFunction().getContext();
   SPIRVType *NewPtrType =
-      createNewPtrType(GR, I, OpType, false, true, nullptr,
+      createNewPtrType(GR, I, OpType, false, false, nullptr,
                        TargetExtType::get(Context, "spirv.Event"));
   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
 }
@@ -216,7 +216,8 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
   MachineIRBuilder MIB(I);
   LLVMContext &Context = MF->getFunction().getContext();
   SPIRVType *ElemType =
-      GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB);
+      GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB,
+                              SPIRV::AccessQualifier::ReadWrite, false);
   SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
   doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
 }

>From cfe50c14555d4ebdb24e7f07ebc934a8f2279870 Mon Sep 17 00:00:00 2001
From: "Levytskyy, Vyacheslav" <vyacheslav.levytskyy at intel.com>
Date: Fri, 21 Feb 2025 02:39:27 -0800
Subject: [PATCH 2/2] add explicit EmitIR flag in utils

---
 llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp |  6 ++++--
 llvm/lib/Target/SPIRV/SPIRVUtils.cpp        | 19 ++++++++++++-------
 llvm/lib/Target/SPIRV/SPIRVUtils.h          |  9 ++++++---
 3 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index e47dfddd55975..e4f144eeffbbc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -557,10 +557,12 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
               RetTy =
                   TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
         }
-        setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
+        setRegClassType(ResVReg, RetTy, GR, MIRBuilder,
+                        SPIRV::AccessQualifier::ReadWrite, true);
       }
     } else {
-      ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
+      ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder,
+                                      SPIRV::AccessQualifier::ReadWrite, true);
     }
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index ddc66f98829a9..c55b735314228 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -738,9 +738,12 @@ void setRegClassType(Register Reg, SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
 // no valid assigned class, set register LLT type and class according to the
 // SPIR-V type.
 void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
-                     MachineIRBuilder &MIRBuilder, bool Force) {
-  setRegClassType(Reg, GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
-                  MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
+                     MachineIRBuilder &MIRBuilder,
+                     SPIRV::AccessQualifier::AccessQualifier AccessQual,
+                     bool EmitIR, bool Force) {
+  setRegClassType(Reg,
+                  GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR),
+                  GR, MIRBuilder.getMRI(), MIRBuilder.getMF(), Force);
 }
 
 // Create a virtual register and assign SPIR-V type to the register. Set
@@ -764,10 +767,12 @@ Register createVirtualRegister(SPIRVType *SpvType, SPIRVGlobalRegistry *GR,
 
 // Create a SPIR-V type, virtual register and assign SPIR-V type to the
 // register. Set register LLT type and class according to the SPIR-V type.
-Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
-                               MachineIRBuilder &MIRBuilder) {
-  return createVirtualRegister(GR->getOrCreateSPIRVType(Ty, MIRBuilder), GR,
-                               MIRBuilder);
+Register createVirtualRegister(
+    const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
+    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
+  return createVirtualRegister(
+      GR->getOrCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR), GR,
+      MIRBuilder);
 }
 
 // Return true if there is an opaque pointer type nested in the argument.
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 552adf2df7d17..870649879218a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -411,7 +411,9 @@ MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
 bool getVacantFunctionName(Module &M, std::string &Name);
 
 void setRegClassType(Register Reg, const Type *Ty, SPIRVGlobalRegistry *GR,
-                     MachineIRBuilder &MIRBuilder, bool Force = false);
+                     MachineIRBuilder &MIRBuilder,
+                     SPIRV::AccessQualifier::AccessQualifier AccessQual,
+                     bool EmitIR, bool Force = false);
 void setRegClassType(Register Reg, const MachineInstr *SpvType,
                      SPIRVGlobalRegistry *GR, MachineRegisterInfo *MRI,
                      const MachineFunction &MF, bool Force = false);
@@ -422,8 +424,9 @@ Register createVirtualRegister(const MachineInstr *SpvType,
 Register createVirtualRegister(const MachineInstr *SpvType,
                                SPIRVGlobalRegistry *GR,
                                MachineIRBuilder &MIRBuilder);
-Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
-                               MachineIRBuilder &MIRBuilder);
+Register createVirtualRegister(
+    const Type *Ty, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIRBuilder,
+    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR);
 
 // Return true if there is an opaque pointer type nested in the argument.
 bool isNestedPointer(const Type *Ty);



More information about the llvm-commits mailing list