[llvm] [SPIR-V] Fix issue #120078 and simplify parsing of floating point decoration tips in demangled function name (PR #120128)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 16 11:19:55 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

<details>
<summary>Changes</summary>

This PR fixes issue https://github.com/llvm/llvm-project/issues/120078 and improves/simplifies the parsing of the demangled function name, which is used to detect a tip for floating point decorations. The latter improvement also fixes a complaint from `LLVM_USE_SANITIZER=Address`.

---
Full diff: https://github.com/llvm/llvm-project/pull/120128.diff


8 Files Affected:

- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+10-6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.h (+1-1) 
- (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+26-26) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+30-6) 
- (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+4) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+7-2) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+13) 
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+17) 


``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 4bfa51e2cccdd8..e236d646e66fc1 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -174,7 +174,7 @@ using namespace InstructionSet;
 namespace SPIRV {
 /// Parses the name part of the demangled builtin call.
 std::string lookupBuiltinNameHelper(StringRef DemangledCall,
-                                    std::string *Postfix) {
+                                    FPDecorationId *DecorationId) {
   const static std::string PassPrefix = "(anonymous namespace)::";
   std::string BuiltinName;
   // Itanium Demangler result may have "(anonymous namespace)::" prefix
@@ -232,12 +232,16 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
       "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
       "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
       "Convert|"
-      "UConvert|SConvert|FConvert|SatConvert).*)_R(.*)");
+      "UConvert|SConvert|FConvert|SatConvert).*)_R[^_]*_?(\\w+)?.*");
   std::smatch Match;
-  if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 3) {
-    BuiltinName = Match[1].str();
-    if (Postfix)
-      *Postfix = Match[3].str();
+  if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {
+    std::ssub_match SubMatch;
+    if (DecorationId && Match.size() > 3) {
+      SubMatch = Match[3];
+      *DecorationId = demangledPostfixToDecorationId(SubMatch.str());
+    }
+    SubMatch = Match[1];
+    BuiltinName = SubMatch.str();
   }
 
   return BuiltinName;
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
index 0182d9652d18c9..1a8641a8328ddd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.h
@@ -21,7 +21,7 @@ namespace llvm {
 namespace SPIRV {
 /// Parses the name part of the demangled builtin call.
 std::string lookupBuiltinNameHelper(StringRef DemangledCall,
-                                    std::string *Postfix = nullptr);
+                                    FPDecorationId *DecorationId = nullptr);
 /// Lowers a builtin function call using the provided \p DemangledCall skeleton
 /// and external instruction \p Set.
 ///
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 433956f44917fb..77b54219a9acc4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -1876,18 +1876,6 @@ bool SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
   return true;
 }
 
-static unsigned roundingModeMDToDecorationConst(StringRef S) {
-  if (S == "rte")
-    return SPIRV::FPRoundingMode::FPRoundingMode::RTE;
-  if (S == "rtz")
-    return SPIRV::FPRoundingMode::FPRoundingMode::RTZ;
-  if (S == "rtp")
-    return SPIRV::FPRoundingMode::FPRoundingMode::RTP;
-  if (S == "rtn")
-    return SPIRV::FPRoundingMode::FPRoundingMode::RTN;
-  return std::numeric_limits<unsigned>::max();
-}
-
 void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
                                                 IRBuilder<> &B) {
   // TODO: extend the list of functions with known result types
@@ -1905,9 +1893,10 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
       Function *CalledF = CI->getCalledFunction();
       std::string DemangledName =
           getOclOrSpirvBuiltinDemangledName(CalledF->getName());
-      std::string Postfix;
+      FPDecorationId DecorationId = FPDecorationId::NONE;
       if (DemangledName.length() > 0)
-        DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName, &Postfix);
+        DemangledName =
+            SPIRV::lookupBuiltinNameHelper(DemangledName, &DecorationId);
       auto ResIt = ResTypeWellKnown.find(DemangledName);
       if (ResIt != ResTypeWellKnown.end()) {
         IsKnown = true;
@@ -1919,18 +1908,29 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
           break;
         }
       }
-      // check if a floating rounding mode info is present
-      StringRef S = Postfix;
-      SmallVector<StringRef, 8> Parts;
-      S.split(Parts, "_", -1, false);
-      if (Parts.size() > 1) {
-        // Convert the info about rounding mode into a decoration record.
-        unsigned RoundingModeDeco = roundingModeMDToDecorationConst(Parts[1]);
-        if (RoundingModeDeco != std::numeric_limits<unsigned>::max())
-          createRoundingModeDecoration(CI, RoundingModeDeco, B);
-        // Check if the SaturatedConversion info is present.
-        if (Parts[1] == "sat")
-          createSaturatedConversionDecoration(CI, B);
+      // check if a floating rounding mode or saturation info is present
+      switch (DecorationId) {
+      default:
+        break;
+      case FPDecorationId::SAT:
+        createSaturatedConversionDecoration(CI, B);
+        break;
+      case FPDecorationId::RTE:
+        createRoundingModeDecoration(
+            CI, SPIRV::FPRoundingMode::FPRoundingMode::RTE, B);
+        break;
+      case FPDecorationId::RTZ:
+        createRoundingModeDecoration(
+            CI, SPIRV::FPRoundingMode::FPRoundingMode::RTZ, B);
+        break;
+      case FPDecorationId::RTP:
+        createRoundingModeDecoration(
+            CI, SPIRV::FPRoundingMode::FPRoundingMode::RTP, B);
+        break;
+      case FPDecorationId::RTN:
+        createRoundingModeDecoration(
+            CI, SPIRV::FPRoundingMode::FPRoundingMode::RTN, B);
+        break;
       }
     }
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 3e913646d57c80..0c424477001062 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -157,28 +157,52 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
   });
 }
 
+void SPIRVGlobalRegistry::invalidateMachineInstr(MachineInstr *MI) {
+  // TODO:
+  // - take into account duplicate tracker case which is a known issue,
+  // - review other data structure wrt. possible issues related to removal
+  //   of a machine instruction during instruction selection.
+  const MachineFunction *MF = MI->getParent()->getParent();
+  auto It = LastInsertedTypeMap.find(MF);
+  if (It == LastInsertedTypeMap.end())
+    return;
+  if (It->second == MI)
+    LastInsertedTypeMap.erase(MF);
+}
+
 SPIRVType *SPIRVGlobalRegistry::createOpType(
     MachineIRBuilder &MIRBuilder,
     std::function<MachineInstr *(MachineIRBuilder &)> Op) {
   auto oldInsertPoint = MIRBuilder.getInsertPt();
   MachineBasicBlock *OldMBB = &MIRBuilder.getMBB();
+  MachineBasicBlock *NewMBB = &*MIRBuilder.getMF().begin();
 
   auto LastInsertedType = LastInsertedTypeMap.find(CurMF);
   if (LastInsertedType != LastInsertedTypeMap.end()) {
     auto It = LastInsertedType->second->getIterator();
-    auto NewMBB = MIRBuilder.getMF().begin();
-    MIRBuilder.setInsertPt(*NewMBB, It->getNextNode()
-                                        ? It->getNextNode()->getIterator()
-                                        : NewMBB->end());
+    // It might happen that this instruction was removed from the first MBB,
+    // hence the Parent's check.
+    MachineBasicBlock::iterator InsertAt;
+    if (It->getParent() != NewMBB)
+      InsertAt = oldInsertPoint->getParent() == NewMBB
+                     ? oldInsertPoint
+                     : getInsertPtValidEnd(NewMBB);
+    else if (It->getNextNode())
+      InsertAt = It->getNextNode()->getIterator();
+    else
+      InsertAt = getInsertPtValidEnd(NewMBB);
+    MIRBuilder.setInsertPt(*NewMBB, InsertAt);
   } else {
-    MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
-                           MIRBuilder.getMF().begin()->begin());
+    MIRBuilder.setInsertPt(*NewMBB, NewMBB->begin());
     auto Result = LastInsertedTypeMap.try_emplace(CurMF, nullptr);
     assert(Result.second);
     LastInsertedType = Result.first;
   }
 
   MachineInstr *Type = Op(MIRBuilder);
+  // We expect all users of this function to insert definitions at the insertion
+  // point set above that is always the first MBB.
+  assert(Type->getParent() == NewMBB);
   LastInsertedType->second = Type;
 
   MIRBuilder.setInsertPt(*OldMBB, oldInsertPoint);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index df92325ed19802..ec2386fa1e56e2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -444,6 +444,10 @@ class SPIRVGlobalRegistry {
   bool isBitcastCompatible(const SPIRVType *Type1,
                            const SPIRVType *Type2) const;
 
+  // Informs about removal of the machine instruction and invalidates data
+  // structures referring this instruction.
+  void invalidateMachineInstr(MachineInstr *MI);
+
 private:
   SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ff8d1f7485e160..eef7cefdeed4c5 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -431,6 +431,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
       }
       MRI->setRegClass(SrcReg, MRI->getRegClass(DstReg));
       MRI->replaceRegWith(SrcReg, DstReg);
+      GR.invalidateMachineInstr(&I);
       I.removeFromParent();
       return true;
     } else if (I.getNumDefs() == 1) {
@@ -445,6 +446,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
     // erase it
     LLVM_DEBUG(dbgs() << "Instruction is folded and dead.\n");
     salvageDebugInfo(*MRI, I);
+    GR.invalidateMachineInstr(&I);
     I.eraseFromParent();
     return true;
   }
@@ -464,6 +466,7 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
     if (HasDefs) // Make all vregs 64 bits (for SPIR-V IDs).
       for (unsigned i = 0; i < I.getNumDefs(); ++i)
         MRI->setType(I.getOperand(i).getReg(), LLT::scalar(64));
+    GR.invalidateMachineInstr(&I);
     I.removeFromParent();
     return true;
   }
@@ -2253,8 +2256,10 @@ bool SPIRVInstructionSelector::selectDiscard(Register ResVReg,
   } else {
     Opcode = SPIRV::OpKill;
     // OpKill must be the last operation of any basic block.
-    MachineInstr *NextI = I.getNextNode();
-    NextI->removeFromParent();
+    if (MachineInstr *NextI = I.getNextNode()) {
+      GR.invalidateMachineInstr(NextI);
+      NextI->removeFromParent();
+    }
   }
 
   MachineBasicBlock &BB = *I.getParent();
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index ce90e335fe404c..ddc66f98829a96 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -194,6 +194,19 @@ MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I) {
   return It;
 }
 
+MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB) {
+  MachineBasicBlock::iterator I = MBB->end();
+  if (I == MBB->begin())
+    return I;
+  --I;
+  while (I->isTerminator() || I->isDebugValue()) {
+    if (I == MBB->begin())
+      break;
+    --I;
+  }
+  return I;
+}
+
 SPIRV::StorageClass::StorageClass
 addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) {
   switch (AddrSpace) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index cc77e0afa275a4..19b9be354b5b3f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -150,6 +150,10 @@ void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder,
 // i.e., at the beginning of the first block of the function.
 MachineBasicBlock::iterator getOpVariableMBBIt(MachineInstr &I);
 
+// Return a valid position for the instruction at the end of the block before
+// terminators and debug instructions.
+MachineBasicBlock::iterator getInsertPtValidEnd(MachineBasicBlock *MBB);
+
 // Convert a SPIR-V storage class to the corresponding LLVM IR address space.
 // TODO: maybe the following two functions should be handled in the subtarget
 // to allow for different OpenCL vs Vulkan handling.
@@ -396,5 +400,18 @@ Register createVirtualRegister(const Type *Ty, SPIRVGlobalRegistry *GR,
 // Return true if there is an opaque pointer type nested in the argument.
 bool isNestedPointer(const Type *Ty);
 
+enum FPDecorationId { NONE, RTE, RTZ, RTP, RTN, SAT };
+
+inline FPDecorationId demangledPostfixToDecorationId(const std::string& S) {
+  static std::unordered_map<std::string, FPDecorationId> Mapping = {
+      {"rte", FPDecorationId::RTE},
+      {"rtz", FPDecorationId::RTZ},
+      {"rtp", FPDecorationId::RTP},
+      {"rtn", FPDecorationId::RTN},
+      {"sat", FPDecorationId::SAT}};
+  auto It = Mapping.find(S);
+  return It == Mapping.end() ? FPDecorationId::NONE : It->second;
+}
+
 } // namespace llvm
 #endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H

``````````

</details>


https://github.com/llvm/llvm-project/pull/120128


More information about the llvm-commits mailing list