[llvm] [PredicateInfo] Support existing `PredicateType` by adding `PredicatePHI` when needing introduction of phi nodes (PR #151132)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 29 05:05:06 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Rajveer Singh Bharadwaj (Rajveer100)

<details>
<summary>Changes</summary>

Resolves #150606

Currently, `ssa.copy` is used mostly for straight-line code, i.e., code without joins or uses of phi nodes. With this change, passes would also be able to pick up the relevant predicate information at join points and further optimize the IR.

---
Full diff: https://github.com/llvm/llvm-project/pull/151132.diff


2 Files Affected:

- (modified) llvm/include/llvm/Transforms/Utils/PredicateInfo.h (+13-1) 
- (modified) llvm/lib/Transforms/Utils/PredicateInfo.cpp (+189) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
index c243e236901d5..dfaeec04daff1 100644
--- a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
+++ b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
@@ -67,7 +67,7 @@ class Value;
 class IntrinsicInst;
 class raw_ostream;
 
-enum PredicateType { PT_Branch, PT_Assume, PT_Switch };
+enum PredicateType { PT_Branch, PT_Assume, PT_Switch, PT_PHI };
 
 /// Constraint for a predicate of the form "cmp Pred Op, OtherOp", where Op
 /// is the value the constraint applies to (the ssa.copy result).
@@ -171,6 +171,18 @@ class PredicateSwitch : public PredicateWithEdge {
   }
 };
 
+class PredicatePHI : public PredicateBase {
+public:
+  BasicBlock *PHIBlock;
+  SmallVector<std::pair<BasicBlock *, PredicateBase *>, 4> IncomingPredicates;
+
+  PredicatePHI(Value *Op, BasicBlock *PHIBB)
+      : PredicateBase(PT_PHI, Op, nullptr), PHIBlock(PHIBB) {}
+  static bool classof(const PredicateBase *PB) { return PB->Type == PT_PHI; }
+
+  LLVM_ABI std::optional<PredicateConstraint> getConstraint() const;
+};
+
 /// Encapsulates PredicateInfo, including all data associated with memory
 /// accesses.
 class PredicateInfo {
diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
index ac413c9b58152..6e5e97e97849c 100644
--- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp
+++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
@@ -214,6 +214,8 @@ class PredicateInfoBuilder {
   // whether it returned a valid result.
   DenseMap<Value *, unsigned int> ValueInfoNums;
 
+  DenseMap<BasicBlock *, SmallVector<Value *, 4>> PHICandidates;
+
   BumpPtrAllocator &Allocator;
 
   ValueInfo &getOrCreateValueInfo(Value *);
@@ -225,6 +227,13 @@ class PredicateInfoBuilder {
                      SmallVectorImpl<Value *> &OpsToRename);
   void processSwitch(SwitchInst *, BasicBlock *,
                      SmallVectorImpl<Value *> &OpsToRename);
+  void identifyPHIInsertionPoints(SmallVectorImpl<Value *> &OpsToRename);
+  bool needsPHIForPredicateInfo(Value *Op, BasicBlock *BB);
+  void insertPredicatePHIs(Value *Op, BasicBlock *BB);
+  void computePredicateIDF(Value *Op,
+                           const SmallPtrSet<BasicBlock *, 8> &DefiningBlocks,
+                           SmallPtrSet<BasicBlock *, 8> &PHIBlocks);
+  void processPredicatePHIs(SmallVectorImpl<Value *> &OpsToRename);
   void renameUses(SmallVectorImpl<Value *> &OpsToRename);
   void addInfoFor(SmallVectorImpl<Value *> &OpsToRename, Value *Op,
                   PredicateBase *PB);
@@ -453,6 +462,155 @@ void PredicateInfoBuilder::processSwitch(
   }
 }
 
+void PredicateInfoBuilder::identifyPHIInsertionPoints(
+    SmallVectorImpl<Value *> &OpsToRename) {
+  for (Value *Op : OpsToRename) {
+    const auto &ValueInfo = getValueInfo(Op);
+    SmallPtrSet<BasicBlock *, 8> DefiningBlocks;
+
+    for (const auto *PInfo : ValueInfo.Infos) {
+      if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
+        DefiningBlocks.insert(PBranch->To);
+      } else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
+        DefiningBlocks.insert(PSwitch->To);
+      } else if (auto *PAssume = dyn_cast<PredicateAssume>(PInfo)) {
+        DefiningBlocks.insert(PAssume->AssumeInst->getParent());
+      }
+    }
+
+    if (DefiningBlocks.size() > 1) {
+      SmallPtrSet<BasicBlock *, 8> PHIBlocks;
+      computePredicateIDF(Op, DefiningBlocks, PHIBlocks);
+
+      for (BasicBlock *PHIBlock : PHIBlocks) {
+        if (needsPHIForPredicateInfo(Op, PHIBlock)) {
+          PHICandidates[PHIBlock].push_back(Op);
+        }
+      }
+    }
+  }
+}
+
+void PredicateInfoBuilder::computePredicateIDF(
+    Value *Op, const SmallPtrSet<BasicBlock *, 8> &DefiningBlocks,
+    SmallPtrSet<BasicBlock *, 8> &PHIBlocks) {
+
+  SmallVector<BasicBlock *, 8> Worklist(DefiningBlocks.begin(),
+                                        DefiningBlocks.end());
+
+  while (!Worklist.empty()) {
+    BasicBlock *BB = Worklist.pop_back_val();
+
+    DomTreeNode *Node = DT.getNode(BB);
+    if (!Node)
+      continue;
+
+    for (BasicBlock *Succ : successors(BB)) {
+      if (!DT.dominates(BB, Succ)) {
+        BasicBlock *IDom = DT.getNode(BB)->getIDom()->getBlock();
+        if (DT.dominates(IDom, Succ)) {
+          if (PHIBlocks.insert(Succ).second) {
+            bool HasOpUse = false;
+            for (auto &I : *Succ) {
+              for (Use &U : I.uses()) {
+                if (U.get() == Op) {
+                  HasOpUse = true;
+                  break;
+                }
+              }
+            }
+            if (HasOpUse) {
+              Worklist.push_back(Succ);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+bool PredicateInfoBuilder::needsPHIForPredicateInfo(Value *Op, BasicBlock *BB) {
+  if (BB->getSinglePredecessor())
+    return false;
+
+  const auto &ValueInfo = getValueInfo(Op);
+  SmallDenseSet<PredicateBase *, 4> PredPredicates;
+
+  for (BasicBlock *Pred : predecessors(BB)) {
+    PredicateBase *PredInfo = nullptr;
+
+    for (const auto *PInfo : ValueInfo.Infos) {
+      if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
+        if (PBranch->From == Pred && PBranch->To == BB) {
+          PredInfo = const_cast<PredicateBase *>(PInfo);
+          break;
+        }
+      } else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
+        if (PSwitch->From == Pred && PSwitch->To == BB) {
+          PredInfo = const_cast<PredicateBase *>(PInfo);
+          break;
+        }
+      }
+    }
+
+    if (PredInfo) {
+      PredPredicates.insert(PredInfo);
+    }
+  }
+
+  return PredPredicates.size() > 1 ||
+         (PredPredicates.size() == 1 && pred_size(BB) > 1);
+}
+
+void PredicateInfoBuilder::insertPredicatePHIs(Value *Op, BasicBlock *BB) {
+  IRBuilder<> Builder(&BB->front());
+  PHINode *PHI = Builder.CreatePHI(Op->getType(), pred_size(BB),
+                                   Op->getName() + ".predicate.phi");
+
+  auto *PPhi = new (Allocator) PredicatePHI(Op, BB);
+  PPhi->RenamedOp = PHI;
+
+  const auto &ValueInfo = getValueInfo(Op);
+  for (BasicBlock *Pred : predecessors(BB)) {
+    Value *IncomingValue = Op;
+
+    for (const auto *PInfo : ValueInfo.Infos) {
+      if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
+        if (PBranch->From == Pred && PBranch->To == BB) {
+          PPhi->IncomingPredicates.push_back(
+              {Pred, const_cast<PredicateBase *>(PInfo)});
+          break;
+        }
+      } else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
+        if (PSwitch->From == Pred && PSwitch->To == BB) {
+          PPhi->IncomingPredicates.push_back(
+              {Pred, const_cast<PredicateBase *>(PInfo)});
+          break;
+        }
+      }
+    }
+
+    PHI->addIncoming(IncomingValue, Pred);
+  }
+
+  PI.PredicateMap.insert({PHI, PPhi});
+
+  auto &OperandInfo = getOrCreateValueInfo(Op);
+  OperandInfo.Infos.push_back(PPhi);
+}
+
+void PredicateInfoBuilder::processPredicatePHIs(
+    SmallVectorImpl<Value *> &OpsToRename) {
+
+  for (const auto &Entry : PHICandidates) {
+    BasicBlock *PHIBlock = Entry.first;
+
+    for (Value *Op : Entry.second) {
+      insertPredicatePHIs(Op, PHIBlock);
+    }
+  }
+}
+
 // Build predicate info for our function
 void PredicateInfoBuilder::buildPredicateInfo() {
   DT.updateDFSNumbers();
@@ -479,6 +637,10 @@ void PredicateInfoBuilder::buildPredicateInfo() {
       if (DT.isReachableFromEntry(II->getParent()))
         processAssume(II, II->getParent(), OpsToRename);
   }
+
+  identifyPHIInsertionPoints(OpsToRename);
+  processPredicatePHIs(OpsToRename);
+
   // Now rename all our operations.
   renameUses(OpsToRename);
 }
@@ -774,10 +936,33 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const {
     }
 
     return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}};
+  case PT_PHI:
+    return cast<PredicatePHI>(this)->getConstraint();
   }
   llvm_unreachable("Unknown predicate type");
 }
 
+std::optional<PredicateConstraint> PredicatePHI::getConstraint() const {
+  // For PHI predicates, find the common constraint across all incoming edges
+  if (IncomingPredicates.empty())
+    return std::nullopt;
+
+  auto FirstConstraint = IncomingPredicates[0].second->getConstraint();
+  if (!FirstConstraint)
+    return std::nullopt;
+
+  // Verify all incoming predicates have the same constraint
+  for (size_t I = 1; I < IncomingPredicates.size(); ++I) {
+    auto Constraint = IncomingPredicates[I].second->getConstraint();
+    if (!Constraint || Constraint->Predicate != FirstConstraint->Predicate ||
+        Constraint->OtherOp != FirstConstraint->OtherOp) {
+      return std::nullopt;
+    }
+  }
+
+  return FirstConstraint;
+}
+
 void PredicateInfo::verifyPredicateInfo() const {}
 
 // Replace ssa_copy calls created by PredicateInfo with their operand.
@@ -839,6 +1024,10 @@ class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter {
       } else if (const auto *PA = dyn_cast<PredicateAssume>(PI)) {
         OS << "; assume predicate info {"
            << " Comparison:" << *PA->Condition;
+      } else if (const auto *PP = dyn_cast<PredicatePHI>(PI)) {
+        OS << "; phi predicate info { PHIBlock: ";
+        PP->PHIBlock->printAsOperand(OS);
+        OS << " IncomingEdges: " << PP->IncomingPredicates.size();
       }
       OS << ", RenamedOp: ";
       PI->RenamedOp->printAsOperand(OS, false);

``````````

</details>


https://github.com/llvm/llvm-project/pull/151132


More information about the llvm-commits mailing list