[llvm] [GVN] Support rnflow pattern matching and transform (PR #162259)

Madhur Amilkanthwar via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 9 00:56:48 PDT 2026


================
@@ -3339,6 +3336,302 @@ void GVNPass::assignValNumForDeadCode() {
   }
 }
 
+/// Decide, via MemorySSA, whether a load inside loop \p L may be moved to the
+/// loop preheader.
+///
+/// The clobbering access of \p LoadInst is looked up through the skip-self
+/// walker. Hoisting is allowed when there is no clobber, when the clobber is
+/// the live-on-entry definition, or when the clobber's block is outside \p L.
+///
+/// \param L        Loop the load currently lives in.
+/// \param LoadInst Instruction whose MemoryAccess is queried; it must have
+///                 one (asserted).
+/// \param MSSAU    Updater used only to reach the underlying MemorySSA.
+/// \returns true if no clobbering access was found inside \p L.
+static bool canHoistLoadWithMSSA(Loop *L, Instruction *LoadInst,
+                                 MemorySSAUpdater *MSSAU) {
+  MemorySSA *MSSA = MSSAU->getMemorySSA();
+  MemoryAccess *LoadAccess = MSSA->getMemoryAccess(LoadInst);
+  assert(LoadAccess && "MemoryAccess expected when MemorySSA is available");
+
+  MemoryAccess *Clobber =
+      MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(LoadAccess);
+
+  // Only a real (non live-on-entry) clobber located inside the loop blocks
+  // the hoist.
+  const bool ClobberedInLoop = Clobber && !MSSA->isLiveOnEntryDef(Clobber) &&
+                               L->contains(Clobber->getBlock());
+  if (!ClobberedInLoop)
+    return true;
+
+  LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - clobbered in loop by " << *Clobber
+                    << "\n");
+  return false;
+}
+
+/// Decide, via MemoryDependenceResults, whether \p Load may be moved from
+/// loop \p L to the loop preheader.
+///
+/// The dependency reported for \p Load is classified as follows:
+///  - local def/clobber: blocks the hoist if the depending instruction is in
+///    a block of \p L;
+///  - non-local: every per-block result is inspected and a def/clobber in a
+///    block of \p L blocks the hoist;
+///  - unknown: always blocks the hoist.
+/// Any other kind of result does not block the hoist.
+///
+/// \param L    Loop the load currently lives in.
+/// \param Load Load being considered for hoisting.
+/// \param MD   Memory dependence analysis to query.
+/// \returns true if nothing inside \p L was found to clobber the load.
+static bool canHoistLoadWithMD(Loop *L, LoadInst *Load,
+                               MemoryDependenceResults *MD) {
+  const MemDepResult LocalDep = MD->getDependency(Load);
+
+  // Case 1: the dependency was resolved to a concrete instruction.
+  if (LocalDep.isLocal() && (LocalDep.isDef() || LocalDep.isClobber())) {
+    Instruction *Blocker = LocalDep.getInst();
+    if (!Blocker || !L->contains(Blocker->getParent()))
+      return true;
+    LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - clobbered in loop by "
+                      << *Blocker << "\n");
+    return false;
+  }
+
+  // Case 2: the dependency lives in other blocks; walk the per-block results.
+  if (LocalDep.isNonLocal()) {
+    SmallVector<NonLocalDepResult, 64> BlockDeps;
+    MD->getNonLocalPointerDependency(Load, BlockDeps);
+    for (const NonLocalDepResult &Entry : BlockDeps) {
+      if (!L->contains(Entry.getBB()))
+        continue;
+      const MemDepResult &Res = Entry.getResult();
+      if (!Res.isDef() && !Res.isClobber())
+        continue;
+      LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - clobbered in loop (block "
+                        << Entry.getBB()->getName() << ")\n");
+      return false;
+    }
+    return true;
+  }
+
+  // Case 3: the analysis could not classify the dependency at all.
+  if (LocalDep.isUnknown()) {
+    LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - unknown memory dependence\n");
+    return false;
+  }
+
+  return true;
+}
+
+/// Hoist the chain of operations for the second load to the preheader.
+///
+/// The load of the current minimum (\p LoadVal) is re-created once in the
+/// preheader, using the initial value of the index PHI. Inside the loop the
+/// running minimum is then carried in a new PHI node and updated by a new
+/// select, so the value no longer has to be reloaded on every iteration.
+///
+/// *** Before transformation ***
+///
+///  preheader:
+///    ...
+///  loop:
+///    ...
+///    ...
+///    %val.first = load <TYPE>, ptr %ptr.first.load, align 4
+///    %min.idx.ext = sext i32 %min.idx to i64
+///    %ptr.<TYPE>.min = getelementptr <TYPE>, ptr %0, i64 %min.idx.ext
+///    %ptr.second.load = getelementptr i8, ptr %ptr.<TYPE>.min, i64 -4
+///    %val.current.min = load <TYPE>, ptr %ptr.second.load, align 4
+///    ...
+///    ...
+///    br i1 %cond, label %loop, label %exit
+///
+///    We capture <TYPE> as a part of pattern matching and then later
+///    use it in the transformation.
+///
+/// *** After transformation ***
+///
+///  preheader:
+///    %hoist_sext = sext i32 %initial_min_idx to i64
+///    %hoist_gep1 = getelementptr <TYPE>, ptr %0, i64 %hoist_sext
+///    %hoist_gep2 = getelementptr i8, ptr %hoist_gep1, i64 -4
+///    %hoisted_load = load <TYPE>, ptr %hoist_gep2, align 4
+///    br label %loop
+///
+///  loop:
+///    %known_min = phi <TYPE> [ %hoisted_load, %preheader ],
+///                            [ %current_min, %loop ]
+///    ...
+///    %val.first = load <TYPE>, ptr %ptr.first.load, align 4
+///    ...
+///    (the comparison now uses %known_min as its second operand)
+///    ...
+///    %current_min = select i1 %cmp, <TYPE> %val.first, <TYPE> %known_min
+///    br i1 %cond, label %loop, label %exit
+///
+/// (The diagram assumes \p LHS is %val.first.)
+///
+/// \param L           Loop being transformed; used for the clobber check.
+/// \param LoadType    Type of the hoisted load, of the new PHI and element
+///                    type of the first hoisted GEP.
+/// \param Preheader   Block that receives the hoisted instructions, inserted
+///                    before its terminator.
+/// \param BB          Block that receives the PHI (at the top) and the select
+///                    (before the terminator); also the PHI's second incoming
+///                    block.
+/// \param LHS         Value chosen by the new select when \p Comparison holds.
+/// \param LoadVal     The load to replace; must be a LoadInst.
+/// \param Comparison  Compare whose operand 1 is redirected to the new PHI.
+/// \param Select      Not referenced by this function body.
+/// \param BasePtr     Base pointer of the first hoisted GEP.
+/// \param IndexValPhi Index PHI; its incoming value from \p Preheader is the
+///                    index used for the hoisted address.
+/// \param OffsetVal   Byte offset applied by the second (i8) hoisted GEP.
+/// \returns true if the IR was rewritten; false if neither MemorySSA nor
+///          MemoryDependence is available or a clobber was found in the loop.
+bool GVNPass::transformMinFindingSelectPattern(
+    Loop *L, Type *LoadType, BasicBlock *Preheader, BasicBlock *BB, Value *LHS,
+    Value *LoadVal, CmpInst *Comparison, SelectInst *Select, Value *BasePtr,
+    PHINode *IndexValPhi, Value *OffsetVal) {
+
+  assert(BasePtr && "BasePtr is null");
+  assert(OffsetVal && "OffsetVal is null");
+  assert(IndexValPhi && "IndexValPhi is null");
+  // Query inside the assert so NDEBUG builds are not left with an unused
+  // local variable.
+  assert(VN.getAliasAnalysis() && "AA is null");
+
+  // The value is dereferenced unconditionally below, so use cast<> (which
+  // asserts) rather than dyn_cast<> (which would silently yield null).
+  auto *OrigLoad = cast<LoadInst>(LoadVal);
+
+  // Check if any instruction in the loop clobbers this location. Require MSSA
+  // or MD to perform the transformation.
+  bool CanHoist = false;
+  if (MSSAU)
+    CanHoist = canHoistLoadWithMSSA(L, OrigLoad, MSSAU);
+  else if (MD)
+    CanHoist = canHoistLoadWithMD(L, OrigLoad, MD);
+
+  if (!CanHoist) {
+    LLVM_DEBUG(dbgs() << "GVN: Cannot hoist - may be clobbered by some "
+                         "instruction in the loop.\n");
+    return false;
+  }
+
+  IRBuilder<> Builder(Preheader->getTerminator());
+  Value *InitialMinIndex = IndexValPhi->getIncomingValueForBlock(Preheader);
+
+  // Insert PHI node at the top of this block.
+  // This PHI node will be used to memoize the current minimum value so far.
+  PHINode *KnownMinPhi = PHINode::Create(LoadType, 2, "known_min", BB->begin());
+
+  // Hoist the load and build the necessary operations.
+  // 1. hoist_sext = sext InitialMinIndex to i64
+  Value *HoistedSExt =
+      Builder.CreateSExt(InitialMinIndex, Builder.getInt64Ty(), "hoist_sext");
+
+  // 2. hoist_gep1 = getelementptr LoadType, ptr BasePtr, i64 HoistedSExt
+  Value *HoistedGEP1 =
+      Builder.CreateGEP(LoadType, BasePtr, HoistedSExt, "hoist_gep1");
+
+  // 3. hoist_gep2 = getelementptr i8, ptr HoistedGEP1, i64 OffsetVal
+  Value *HoistedGEP2 = Builder.CreateGEP(Builder.getInt8Ty(), HoistedGEP1,
+                                         OffsetVal, "hoist_gep2");
+
+  // 4. hoisted_load = load LoadType, ptr HoistedGEP2
+  LoadInst *NewLoad = Builder.CreateLoad(LoadType, HoistedGEP2, "hoisted_load");
+  // Carry over the alignment of the load being replaced instead of silently
+  // falling back to the builder's default alignment for LoadType.
+  NewLoad->setAlignment(OrigLoad->getAlign());
+
+  // Update MemorySSA before erasing the original load.
+  if (MSSAU) {
+    auto *OrigUse = MSSAU->getMemorySSA()->getMemoryAccess(OrigLoad);
+    if (OrigUse) {
+      MemoryAccess *DefiningAccess = OrigUse->getDefiningAccess();
+      MSSAU->createMemoryAccessInBB(NewLoad, DefiningAccess, Preheader,
+                                    MemorySSA::BeforeTerminator);
+      MSSAU->removeMemoryAccess(OrigUse);
+    }
+  }
+
+  // Invalidate MD cache for the loaded pointer; we added a new load and removed
+  // the old one (removeInstruction handles removing the old load from MD).
+  if (MD)
+    MD->invalidateCachedPointerInfo(NewLoad->getPointerOperand());
+
+  // Let the new load now take the place of the old load.
+  OrigLoad->replaceAllUsesWith(NewLoad);
+  if (uint32_t ValNo = VN.lookup(OrigLoad, false))
+    LeaderTable.erase(ValNo, OrigLoad, OrigLoad->getParent());
+  removeInstruction(OrigLoad);
+
+  // Comparison should now compare the current value and the newly inserted
+  // PHI node.
+  Comparison->setOperand(1, KnownMinPhi);
+
+  // Create new select instruction for selecting the minimum value. The
+  // builder's result is used as a plain Value; a checked downcast is not
+  // needed to feed the PHI.
+  IRBuilder<> SelectBuilder(BB->getTerminator());
+  Value *CurrentMinSelect =
+      SelectBuilder.CreateSelect(Comparison, LHS, KnownMinPhi, "current_min");
+
+  // Populate the newly created PHI node
+  // with (hoisted) NewLoad from the preheader and CurrentMinSelect.
+  KnownMinPhi->addIncoming(NewLoad, Preheader);
+  KnownMinPhi->addIncoming(CurrentMinSelect, BB);
+  LLVM_DEBUG(
+      dbgs() << "GVN: Transformed the code for minimum finding pattern.\n");
+  return true;
+}
+
+/// We are looking for the following pattern:
+/// loop:
+///   ...
+///   ...
+///   %min.idx = phi i32 [ %initial_min_idx, %entry ], [ %min.idx.next, %loop ]
+///   ...
+///   %val.first = load <TYPE>, ptr %ptr.first.load, align 4
+///   %min.idx.ext = sext i32 %min.idx to i64
+///   %ptr.<TYPE>.min = getelementptr <TYPE>, ptr %0, i64 %min.idx.ext
+///   %ptr.second.load = getelementptr i8, ptr %ptr.<TYPE>.min, i64 -4
+///   %val.current.min = load <TYPE>, ptr %ptr.second.load, align 4
+///   %cmp = <CMP_INST> <TYPE> %val.first, %val.current.min
+///   ...
+///   %min.idx.next = select i1 %cmp, ..., i32 %min.idx
+///   ...
+///   ...
+///   br i1 ..., label %loop, ...
+bool GVNPass::recognizeMinFindingSelectPattern(SelectInst *Select) {
+  Value *OffsetVal = nullptr;
+  BasicBlock *BB = Select->getParent();
+
+  // If the block is not in a loop, bail out.
+  Loop *L = LI->getLoopFor(BB);
+  if (!L)
+    return false;
+
+  // If preheader of the loop is not found, bail out.
+  BasicBlock *Preheader = L->getLoopPreheader();
+  if (!Preheader) {
+    LLVM_DEBUG(dbgs() << "GVN (minindx): Could not find loop preheader.\n");
+    return false;
+  }
+  Value *Condition = Select->getCondition();
+  CmpInst *Comparison = dyn_cast<CmpInst>(Condition);
+  if (!Comparison) {
+    LLVM_DEBUG(dbgs() << "GVN (minindx): Condition is not a comparison.\n");
+    return false;
+  }
+
+  // Check if this is less-than comparison.
+  CmpInst::Predicate Pred = Comparison->getPredicate();
+  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT &&
+      Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT) {
+    LLVM_DEBUG(
+        dbgs() << "GVN (minindx): Not a less-than comparison, predicate: "
+               << Pred << "\n");
+    return false;
+  }
+
+  // Check that both operands are loads.
+  Value *LHS = Comparison->getOperand(0);
+  Value *RHS = Comparison->getOperand(1);
+  if (!isa<LoadInst>(LHS) || !isa<LoadInst>(RHS)) {
+    LLVM_DEBUG(dbgs() << "GVN (minindx): Not both operands are loads.\n");
+    return false;
+  }
+
+  // Check if the type of both loads are the same.
+  if (LHS->getType() != RHS->getType()) {
+    LLVM_DEBUG(
+        dbgs() << "GVN (minindx): Not both loads are of the same type.\n");
+    return false;
+  }
+  Type *LoadType = LHS->getType();
+  Value *InnerGEP;
+  const APInt *OffsetAPInt;
+  if (!match(RHS, m_Load(m_PtrAdd(m_Value(InnerGEP), m_APInt(OffsetAPInt))))) {
+    LLVM_DEBUG(dbgs() << "GVN (minindx): Not a required load pattern.\n");
+    return false;
+  }
+  auto *TypedGEP = dyn_cast<GetElementPtrInst>(InnerGEP);
+  if (!TypedGEP) {
+    LLVM_DEBUG(dbgs() << "GVN (minindx): Not a typed GEP.\n");
+    return false;
+  }
+  Type *ElemTy = TypedGEP->getSourceElementType();
+  // Check if ElemTy is same as LoadType.
+  if (ElemTy != LoadType) {
+    LLVM_DEBUG(dbgs() << "GVN (minindx): Not a required element type.\n");
+    return false;
+  }
+  OffsetVal =
+      ConstantInt::get(Type::getInt64Ty(RHS->getContext()), *OffsetAPInt);
+
+  // Check if the second operand of InnerGEP is a sext instruction.
+  auto *SEInst = dyn_cast<SExtInst>(TypedGEP->getOperand(1));
+  if (!SEInst) {
+    LLVM_DEBUG(dbgs() << "GVN (minindx): Not a sext instruction.\n");
+    return false;
+  }
+
+  // Check if the "To" and "from" type of the sext instruction are i64 and i32
----------------
madhur13490 wrote:

Done.

https://github.com/llvm/llvm-project/pull/162259


More information about the llvm-commits mailing list