[llvm] [PredicateInfo] Support existing `PredicateType` by adding `PredicatePHI` when needing introduction of phi nodes (PR #151132)

Rajveer Singh Bharadwaj via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 29 05:04:29 PDT 2025


https://github.com/Rajveer100 created https://github.com/llvm/llvm-project/pull/151132

Resolves #150606

Currently, `ssa.copy` is used mostly for straight-line code, i.e., code without joins or uses of phi nodes. With this change, passes would be able to pick up the relevant info and further optimize the IR.

From 439b98a98899878929a8317d9ee87c7839164b32 Mon Sep 17 00:00:00 2001
From: Rajveer <rajveer.developer at icloud.com>
Date: Tue, 29 Jul 2025 17:21:36 +0530
Subject: [PATCH] [PredicateInfo] Support existing `PredicateType` by adding
 `PredicatePHI` when needing introduction of phi nodes

Resolves #150606

Currently, `ssa.copy` is used mostly for straight-line code, i.e., code
without joins or uses of phi nodes. With this change, passes would be able
to pick up the relevant info and further optimize the IR.
---
 .../llvm/Transforms/Utils/PredicateInfo.h     |  14 +-
 llvm/lib/Transforms/Utils/PredicateInfo.cpp   | 189 ++++++++++++++++++
 2 files changed, 202 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
index c243e236901d5..dfaeec04daff1 100644
--- a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
+++ b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
@@ -67,7 +67,7 @@ class Value;
 class IntrinsicInst;
 class raw_ostream;
 
-enum PredicateType { PT_Branch, PT_Assume, PT_Switch };
+enum PredicateType { PT_Branch, PT_Assume, PT_Switch, PT_PHI };
 
 /// Constraint for a predicate of the form "cmp Pred Op, OtherOp", where Op
 /// is the value the constraint applies to (the ssa.copy result).
@@ -171,6 +171,18 @@ class PredicateSwitch : public PredicateWithEdge {
   }
 };
 
+class PredicatePHI : public PredicateBase {
+public:
+  BasicBlock *PHIBlock;
+  SmallVector<std::pair<BasicBlock *, PredicateBase *>, 4> IncomingPredicates;
+
+  PredicatePHI(Value *Op, BasicBlock *PHIBB)
+      : PredicateBase(PT_PHI, Op, nullptr), PHIBlock(PHIBB) {}
+  static bool classof(const PredicateBase *PB) { return PB->Type == PT_PHI; }
+
+  LLVM_ABI std::optional<PredicateConstraint> getConstraint() const;
+};
+
 /// Encapsulates PredicateInfo, including all data associated with memory
 /// accesses.
 class PredicateInfo {
diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
index ac413c9b58152..6e5e97e97849c 100644
--- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp
+++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
@@ -214,6 +214,8 @@ class PredicateInfoBuilder {
   // whether it returned a valid result.
   DenseMap<Value *, unsigned int> ValueInfoNums;
 
+  DenseMap<BasicBlock *, SmallVector<Value *, 4>> PHICandidates;
+
   BumpPtrAllocator &Allocator;
 
   ValueInfo &getOrCreateValueInfo(Value *);
@@ -225,6 +227,13 @@ class PredicateInfoBuilder {
                      SmallVectorImpl<Value *> &OpsToRename);
   void processSwitch(SwitchInst *, BasicBlock *,
                      SmallVectorImpl<Value *> &OpsToRename);
+  void identifyPHIInsertionPoints(SmallVectorImpl<Value *> &OpsToRename);
+  bool needsPHIForPredicateInfo(Value *Op, BasicBlock *BB);
+  void insertPredicatePHIs(Value *Op, BasicBlock *BB);
+  void computePredicateIDF(Value *Op,
+                           const SmallPtrSet<BasicBlock *, 8> &DefiningBlocks,
+                           SmallPtrSet<BasicBlock *, 8> &PHIBlocks);
+  void processPredicatePHIs(SmallVectorImpl<Value *> &OpsToRename);
   void renameUses(SmallVectorImpl<Value *> &OpsToRename);
   void addInfoFor(SmallVectorImpl<Value *> &OpsToRename, Value *Op,
                   PredicateBase *PB);
@@ -453,6 +462,155 @@ void PredicateInfoBuilder::processSwitch(
   }
 }
 
+void PredicateInfoBuilder::identifyPHIInsertionPoints(
+    SmallVectorImpl<Value *> &OpsToRename) {
+  for (Value *Op : OpsToRename) {
+    const auto &ValueInfo = getValueInfo(Op);
+    SmallPtrSet<BasicBlock *, 8> DefiningBlocks;
+
+    for (const auto *PInfo : ValueInfo.Infos) {
+      if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
+        DefiningBlocks.insert(PBranch->To);
+      } else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
+        DefiningBlocks.insert(PSwitch->To);
+      } else if (auto *PAssume = dyn_cast<PredicateAssume>(PInfo)) {
+        DefiningBlocks.insert(PAssume->AssumeInst->getParent());
+      }
+    }
+
+    if (DefiningBlocks.size() > 1) {
+      SmallPtrSet<BasicBlock *, 8> PHIBlocks;
+      computePredicateIDF(Op, DefiningBlocks, PHIBlocks);
+
+      for (BasicBlock *PHIBlock : PHIBlocks) {
+        if (needsPHIForPredicateInfo(Op, PHIBlock)) {
+          PHICandidates[PHIBlock].push_back(Op);
+        }
+      }
+    }
+  }
+}
+
+void PredicateInfoBuilder::computePredicateIDF(
+    Value *Op, const SmallPtrSet<BasicBlock *, 8> &DefiningBlocks,
+    SmallPtrSet<BasicBlock *, 8> &PHIBlocks) {
+
+  SmallVector<BasicBlock *, 8> Worklist(DefiningBlocks.begin(),
+                                        DefiningBlocks.end());
+
+  while (!Worklist.empty()) {
+    BasicBlock *BB = Worklist.pop_back_val();
+
+    DomTreeNode *Node = DT.getNode(BB);
+    if (!Node)
+      continue;
+
+    for (BasicBlock *Succ : successors(BB)) {
+      if (!DT.dominates(BB, Succ)) {
+        BasicBlock *IDom = DT.getNode(BB)->getIDom()->getBlock();
+        if (DT.dominates(IDom, Succ)) {
+          if (PHIBlocks.insert(Succ).second) {
+            bool HasOpUse = false;
+            for (auto &I : *Succ) {
+              for (Use &U : I.uses()) {
+                if (U.get() == Op) {
+                  HasOpUse = true;
+                  break;
+                }
+              }
+            }
+            if (HasOpUse) {
+              Worklist.push_back(Succ);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+bool PredicateInfoBuilder::needsPHIForPredicateInfo(Value *Op, BasicBlock *BB) {
+  if (BB->getSinglePredecessor())
+    return false;
+
+  const auto &ValueInfo = getValueInfo(Op);
+  SmallDenseSet<PredicateBase *, 4> PredPredicates;
+
+  for (BasicBlock *Pred : predecessors(BB)) {
+    PredicateBase *PredInfo = nullptr;
+
+    for (const auto *PInfo : ValueInfo.Infos) {
+      if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
+        if (PBranch->From == Pred && PBranch->To == BB) {
+          PredInfo = const_cast<PredicateBase *>(PInfo);
+          break;
+        }
+      } else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
+        if (PSwitch->From == Pred && PSwitch->To == BB) {
+          PredInfo = const_cast<PredicateBase *>(PInfo);
+          break;
+        }
+      }
+    }
+
+    if (PredInfo) {
+      PredPredicates.insert(PredInfo);
+    }
+  }
+
+  return PredPredicates.size() > 1 ||
+         (PredPredicates.size() == 1 && pred_size(BB) > 1);
+}
+
+void PredicateInfoBuilder::insertPredicatePHIs(Value *Op, BasicBlock *BB) {
+  IRBuilder<> Builder(&BB->front());
+  PHINode *PHI = Builder.CreatePHI(Op->getType(), pred_size(BB),
+                                   Op->getName() + ".predicate.phi");
+
+  auto *PPhi = new (Allocator) PredicatePHI(Op, BB);
+  PPhi->RenamedOp = PHI;
+
+  const auto &ValueInfo = getValueInfo(Op);
+  for (BasicBlock *Pred : predecessors(BB)) {
+    Value *IncomingValue = Op;
+
+    for (const auto *PInfo : ValueInfo.Infos) {
+      if (auto *PBranch = dyn_cast<PredicateBranch>(PInfo)) {
+        if (PBranch->From == Pred && PBranch->To == BB) {
+          PPhi->IncomingPredicates.push_back(
+              {Pred, const_cast<PredicateBase *>(PInfo)});
+          break;
+        }
+      } else if (auto *PSwitch = dyn_cast<PredicateSwitch>(PInfo)) {
+        if (PSwitch->From == Pred && PSwitch->To == BB) {
+          PPhi->IncomingPredicates.push_back(
+              {Pred, const_cast<PredicateBase *>(PInfo)});
+          break;
+        }
+      }
+    }
+
+    PHI->addIncoming(IncomingValue, Pred);
+  }
+
+  PI.PredicateMap.insert({PHI, PPhi});
+
+  auto &OperandInfo = getOrCreateValueInfo(Op);
+  OperandInfo.Infos.push_back(PPhi);
+}
+
+void PredicateInfoBuilder::processPredicatePHIs(
+    SmallVectorImpl<Value *> &OpsToRename) {
+
+  for (const auto &Entry : PHICandidates) {
+    BasicBlock *PHIBlock = Entry.first;
+
+    for (Value *Op : Entry.second) {
+      insertPredicatePHIs(Op, PHIBlock);
+    }
+  }
+}
+
 // Build predicate info for our function
 void PredicateInfoBuilder::buildPredicateInfo() {
   DT.updateDFSNumbers();
@@ -479,6 +637,10 @@ void PredicateInfoBuilder::buildPredicateInfo() {
       if (DT.isReachableFromEntry(II->getParent()))
         processAssume(II, II->getParent(), OpsToRename);
   }
+
+  identifyPHIInsertionPoints(OpsToRename);
+  processPredicatePHIs(OpsToRename);
+
   // Now rename all our operations.
   renameUses(OpsToRename);
 }
@@ -774,10 +936,33 @@ std::optional<PredicateConstraint> PredicateBase::getConstraint() const {
     }
 
     return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}};
+  case PT_PHI:
+    return cast<PredicatePHI>(this)->getConstraint();
   }
   llvm_unreachable("Unknown predicate type");
 }
 
+std::optional<PredicateConstraint> PredicatePHI::getConstraint() const {
+  // For PHI predicates, find the common constraint across all incoming edges
+  if (IncomingPredicates.empty())
+    return std::nullopt;
+
+  auto FirstConstraint = IncomingPredicates[0].second->getConstraint();
+  if (!FirstConstraint)
+    return std::nullopt;
+
+  // Verify all incoming predicates have the same constraint
+  for (size_t I = 1; I < IncomingPredicates.size(); ++I) {
+    auto Constraint = IncomingPredicates[I].second->getConstraint();
+    if (!Constraint || Constraint->Predicate != FirstConstraint->Predicate ||
+        Constraint->OtherOp != FirstConstraint->OtherOp) {
+      return std::nullopt;
+    }
+  }
+
+  return FirstConstraint;
+}
+
 void PredicateInfo::verifyPredicateInfo() const {}
 
 // Replace ssa_copy calls created by PredicateInfo with their operand.
@@ -839,6 +1024,10 @@ class PredicateInfoAnnotatedWriter : public AssemblyAnnotationWriter {
       } else if (const auto *PA = dyn_cast<PredicateAssume>(PI)) {
         OS << "; assume predicate info {"
            << " Comparison:" << *PA->Condition;
+      } else if (const auto *PP = dyn_cast<PredicatePHI>(PI)) {
+        OS << "; phi predicate info { PHIBlock: ";
+        PP->PHIBlock->printAsOperand(OS);
+        OS << " IncomingEdges: " << PP->IncomingPredicates.size();
       }
       OS << ", RenamedOp: ";
       PI->RenamedOp->printAsOperand(OS, false);



More information about the llvm-commits mailing list