[llvm] [SPIR-V] Fix generation of gMIR vs. SPIR-V code from utility methods (PR #128159)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 21 02:23:43 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

The SPIR-V Backend uses the same set of utility functions, mostly though not entirely from SPIRVGlobalRegistry, to generate either gMIR or SPIR-V opcodes, depending on the current stage of translation. This choice is controlled by an explicit EmitIR flag rather than by the current translation pass, and there are legacy pieces of code where the EmitIR flag is declared with a default value of true, which allows utility functions to be used without explicitly declaring whether they are intended to work in the gMIR or in the SPIR-V part of the lowering process. While it may be OK to leave this default EmitIR flag as is when generating scalar integer/float types, as we don't expect to see any dependent opcodes derived from such OpTypeXXX instructions, relying on the default EmitIR value for aggregate types is a source of hidden logical flaws and actual issues. This PR provides a partial fix to the problem: it removes the default value of EmitIR, requiring each call site to explicitly state its intent to generate gMIR or SPIR-V code; it fixes several cases of misuse of EmitIR; and, most importantly, it fixes a nasty logic error in which the EmitIR value actually requested by the caller was not passed on and was replaced by the default value in the `findSPIRVType` call. The latter error was a source of issues in the post-instruction-selection pass, which had been getting gMIR code where SPIR-V was explicitly requested, because of the API in SPIRVGlobalRegistry being overloaded with default parameters (most notably, `findSPIRVType`).

---
Full diff: https://github.com/llvm/llvm-project/pull/128159.diff


6 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+9-7) 
- (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+2-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp (+4-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+18-14) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+10-12) 
- (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+3-2) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 7b897f7e34c6f..c65166902550c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -476,8 +476,9 @@ static bool buildSelectInst(MachineIRBuilder &MIRBuilder,
   if (ReturnType->getOpcode() == SPIRV::OpTypeVector) {
     unsigned Bits = GR->getScalarOrVectorBitWidth(ReturnType);
     uint64_t AllOnes = APInt::getAllOnes(Bits).getZExtValue();
-    TrueConst = GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType);
-    FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType);
+    TrueConst =
+        GR->getOrCreateConsIntVector(AllOnes, MIRBuilder, ReturnType, true);
+    FalseConst = GR->getOrCreateConsIntVector(0, MIRBuilder, ReturnType, true);
   } else {
     TrueConst = GR->buildConstantInt(1, MIRBuilder, ReturnType);
     FalseConst = GR->buildConstantInt(0, MIRBuilder, ReturnType);
@@ -1457,7 +1458,7 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
       ToTruncate = DefaultReg;
     }
     auto NewRegister =
-        GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+        GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
     MIRBuilder.buildCopy(DefaultReg, NewRegister);
   } else { // If it could be in range, we need to load from the given builtin.
     auto Vec3Ty =
@@ -1492,13 +1493,14 @@ static bool genWorkgroupQuery(const SPIRV::IncomingCall *Call,
       GR->assignSPIRVTypeToVReg(BoolType, CompareRegister, MIRBuilder.getMF());
 
       // Use G_ICMP to check if idxVReg < 3.
-      MIRBuilder.buildICmp(CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
-                           GR->buildConstantInt(3, MIRBuilder, IndexType));
+      MIRBuilder.buildICmp(
+          CmpInst::ICMP_ULT, CompareRegister, IndexRegister,
+          GR->buildConstantInt(3, MIRBuilder, IndexType, true));
 
       // Get constant for the default value (0 or 1 depending on which
       // function).
       Register DefaultRegister =
-          GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType);
+          GR->buildConstantInt(DefaultValue, MIRBuilder, PointerSizeType, true);
 
       // Get a register for the selection result (possibly a new temporary one).
       Register SelectionResult = Call->ReturnRegister;
@@ -2277,7 +2279,7 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
       Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
                                            SpvFieldTy, *ST.getInstrInfo());
     } else {
-      Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
+      Const = GR->buildConstantInt(0, MIRBuilder, SpvTy, true);
     }
     if (!LocalWorkSize.isValid())
       LocalWorkSize = Const;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 78f6b188c45c1..e47dfddd55975 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -669,7 +669,8 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   // Make sure there's a valid return reg, even for functions returning void.
   if (!ResVReg.isValid())
     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
-  SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
+  SPIRVType *RetType = GR->assignTypeToVReg(
+      OrigRetTy, ResVReg, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, true);
 
   // Emit the call instruction and its args.
   auto MIB = MIRBuilder.buildInstr(CallOp)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
index b98cef0a4f07f..ee98af5cffe4c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitNonSemanticDI.cpp
@@ -193,7 +193,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
     };
 
     const SPIRVType *VoidTy =
-        GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder);
+        GR->getOrCreateSPIRVType(Type::getVoidTy(*Context), MIRBuilder,
+                                 SPIRV::AccessQualifier::ReadWrite, false);
 
     const auto EmitDIInstruction =
         [&](SPIRV::NonSemanticExtInst::NonSemanticExtInst Inst,
@@ -217,7 +218,8 @@ bool SPIRVEmitNonSemanticDI::emitGlobalDI(MachineFunction &MF) {
         };
 
     const SPIRVType *I32Ty =
-        GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder);
+        GR->getOrCreateSPIRVType(Type::getInt32Ty(*Context), MIRBuilder,
+                                 SPIRV::AccessQualifier::ReadWrite, false);
 
     const Register DwarfVersionReg =
         GR->buildConstantInt(DwarfVersion, MIRBuilder, I32Ty, false);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e2f1b211caa5c..a09474a21534e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -244,7 +244,8 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
         CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
     CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
     if (MIRBuilder)
-      assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
+      assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder,
+                       SPIRV::AccessQualifier::ReadWrite, true);
     else
       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
     DT.add(CI, CurMF, Res);
@@ -271,7 +272,8 @@ SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
         CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
     CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
     if (MIRBuilder)
-      assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
+      assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder,
+                       SPIRV::AccessQualifier::ReadWrite, true);
     else
       assignFloatTypeToVReg(BitWidth, Res, *I, *TII);
     DT.add(CI, CurMF, Res);
@@ -878,12 +880,13 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
   });
 }
 
-SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
-                                                MachineIRBuilder &MIRBuilder,
-                                                bool EmitIR) {
+SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(
+    const StructType *Ty, MachineIRBuilder &MIRBuilder,
+    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
   SmallVector<Register, 4> FieldTypes;
   for (const auto &Elem : Ty->elements()) {
-    SPIRVType *ElemTy = findSPIRVType(toTypedPointer(Elem), MIRBuilder);
+    SPIRVType *ElemTy =
+        findSPIRVType(toTypedPointer(Elem), MIRBuilder, AccQual, EmitIR);
     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
            "Invalid struct element type");
     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
@@ -1017,26 +1020,27 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
   if (Ty->isVoidTy())
     return getOpTypeVoid(MIRBuilder);
   if (Ty->isVectorTy()) {
-    SPIRVType *El =
-        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
+    SPIRVType *El = findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
+                                  MIRBuilder, AccQual, EmitIR);
     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                            MIRBuilder);
   }
   if (Ty->isArrayTy()) {
-    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
+    SPIRVType *El =
+        findSPIRVType(Ty->getArrayElementType(), MIRBuilder, AccQual, EmitIR);
     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
   }
   if (auto SType = dyn_cast<StructType>(Ty)) {
     if (SType->isOpaque())
       return getOpTypeOpaque(SType, MIRBuilder);
-    return getOpTypeStruct(SType, MIRBuilder, EmitIR);
+    return getOpTypeStruct(SType, MIRBuilder, AccQual, EmitIR);
   }
   if (auto FType = dyn_cast<FunctionType>(Ty)) {
-    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
+    SPIRVType *RetTy =
+        findSPIRVType(FType->getReturnType(), MIRBuilder, AccQual, EmitIR);
     SmallVector<SPIRVType *, 4> ParamTypes;
-    for (const auto &t : FType->params()) {
-      ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
-    }
+    for (const auto &ParamTy : FType->params())
+      ParamTypes.push_back(findSPIRVType(ParamTy, MIRBuilder, AccQual, EmitIR));
     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
   }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 0c94ec4df97f5..5ad705f197d5a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -94,13 +94,11 @@ class SPIRVGlobalRegistry {
 
   // Add a new OpTypeXXX instruction without checking for duplicates.
   SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
-                             SPIRV::AccessQualifier::AccessQualifier AQ =
-                                 SPIRV::AccessQualifier::ReadWrite,
-                             bool EmitIR = true);
+                             SPIRV::AccessQualifier::AccessQualifier AQ,
+                             bool EmitIR);
   SPIRVType *findSPIRVType(const Type *Ty, MachineIRBuilder &MIRBuilder,
-                           SPIRV::AccessQualifier::AccessQualifier accessQual =
-                               SPIRV::AccessQualifier::ReadWrite,
-                           bool EmitIR = true);
+                           SPIRV::AccessQualifier::AccessQualifier accessQual,
+                           bool EmitIR);
   SPIRVType *
   restOfCreateSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
                         SPIRV::AccessQualifier::AccessQualifier AccessQual,
@@ -321,9 +319,8 @@ class SPIRVGlobalRegistry {
   // and map it to the given VReg by creating an ASSIGN_TYPE instruction.
   SPIRVType *assignTypeToVReg(const Type *Type, Register VReg,
                               MachineIRBuilder &MIRBuilder,
-                              SPIRV::AccessQualifier::AccessQualifier AQ =
-                                  SPIRV::AccessQualifier::ReadWrite,
-                              bool EmitIR = true);
+                              SPIRV::AccessQualifier::AccessQualifier AQ,
+                              bool EmitIR);
   SPIRVType *assignIntTypeToVReg(unsigned BitWidth, Register VReg,
                                  MachineInstr &I, const SPIRVInstrInfo &TII);
   SPIRVType *assignFloatTypeToVReg(unsigned BitWidth, Register VReg,
@@ -470,13 +467,14 @@ class SPIRVGlobalRegistry {
                              MachineIRBuilder &MIRBuilder);
 
   SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
-                            MachineIRBuilder &MIRBuilder, bool EmitIR = true);
+                            MachineIRBuilder &MIRBuilder, bool EmitIR);
 
   SPIRVType *getOpTypeOpaque(const StructType *Ty,
                              MachineIRBuilder &MIRBuilder);
 
   SPIRVType *getOpTypeStruct(const StructType *Ty, MachineIRBuilder &MIRBuilder,
-                             bool EmitIR = true);
+                             SPIRV::AccessQualifier::AccessQualifier AccQual,
+                             bool EmitIR);
 
   SPIRVType *getOpTypePointer(SPIRV::StorageClass::StorageClass SC,
                               SPIRVType *ElemType, MachineIRBuilder &MIRBuilder,
@@ -539,7 +537,7 @@ class SPIRVGlobalRegistry {
                                     SPIRVType *SpvType,
                                     const SPIRVInstrInfo &TII);
   Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
-                                    SPIRVType *SpvType, bool EmitIR = true);
+                                    SPIRVType *SpvType, bool EmitIR);
   Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                    SPIRVType *SpvType);
   Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index d5b81bf46c804..f622be893919f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -192,7 +192,7 @@ static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
   // Insert a bitcast before the instruction to keep SPIR-V code valid.
   LLVMContext &Context = MF->getFunction().getContext();
   SPIRVType *NewPtrType =
-      createNewPtrType(GR, I, OpType, false, true, nullptr,
+      createNewPtrType(GR, I, OpType, false, false, nullptr,
                        TargetExtType::get(Context, "spirv.Event"));
   doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
 }
@@ -216,7 +216,8 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
   MachineIRBuilder MIB(I);
   LLVMContext &Context = MF->getFunction().getContext();
   SPIRVType *ElemType =
-      GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB);
+      GR.getOrCreateSPIRVType(IntegerType::getInt8Ty(Context), MIB,
+                              SPIRV::AccessQualifier::ReadWrite, false);
   SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ElemType, MIB, SC);
   doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/128159


More information about the llvm-commits mailing list