[llvm] [GVN] Support rnflow pattern matching and transform (PR #162259)
Madhur Amilkanthwar via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 7 03:19:13 PDT 2025
https://github.com/madhur13490 created https://github.com/llvm/llvm-project/pull/162259
None
>From 2499cef2b445c9a4772b4376bcf77bae69f6a20e Mon Sep 17 00:00:00 2001
From: Madhur Amilkanthwar <madhura at nvidia.com>
Date: Fri, 15 Aug 2025 00:34:49 -0700
Subject: [PATCH] [GVN] Support rnflow pattern matching and transform
---
llvm/include/llvm/Transforms/Scalar/GVN.h | 4 +
llvm/lib/Transforms/Scalar/GVN.cpp | 122 ++++++++++++++++++
.../test/Transforms/GVN/PRE/rnflow-gvn-pre.ll | 59 +++++++++
3 files changed, 185 insertions(+)
create mode 100644 llvm/test/Transforms/GVN/PRE/rnflow-gvn-pre.ll
diff --git a/llvm/include/llvm/Transforms/Scalar/GVN.h b/llvm/include/llvm/Transforms/Scalar/GVN.h
index 74a4d6ce00fcc..a73d17b0680de 100644
--- a/llvm/include/llvm/Transforms/Scalar/GVN.h
+++ b/llvm/include/llvm/Transforms/Scalar/GVN.h
@@ -22,6 +22,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Compiler.h"
@@ -44,6 +45,7 @@ class FunctionPass;
class GetElementPtrInst;
class ImplicitControlFlowTracking;
class LoadInst;
+class SelectInst;
class LoopInfo;
class MemDepResult;
class MemoryAccess;
@@ -409,6 +411,8 @@ class GVNPass : public PassInfoMixin<GVNPass> {
void addDeadBlock(BasicBlock *BB);
void assignValNumForDeadCode();
void assignBlockRPONumber(Function &F);
+
+ bool optimizeMinMaxFindingSelectPattern(SelectInst *Select);
};
/// Create a legacy GVN pass.
diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp
index b9b5b5823d780..76653e1a01eae 100644
--- a/llvm/lib/Transforms/Scalar/GVN.cpp
+++ b/llvm/lib/Transforms/Scalar/GVN.cpp
@@ -2818,6 +2818,10 @@ bool GVNPass::processInstruction(Instruction *I) {
}
return Changed;
}
+ if (SelectInst *Select = dyn_cast<SelectInst>(I)) {
+ if (optimizeMinMaxFindingSelectPattern(Select))
+ return true;
+ }
// Instructions with void type don't return a value, so there's
// no point in trying to find redundancies in them.
@@ -3410,6 +3414,124 @@ void GVNPass::assignValNumForDeadCode() {
}
}
+/// Try to keep the running minimum of a "find the minimum" loop in a PHI
+/// instead of reloading it from memory on every iteration.
+///
+/// Recognized shape, all inside the block containing \p Select, which must be
+/// both the header and the only latch of its loop (a single-block loop):
+///
+///   %idx  = phi [ %init, %preheader ], [ %idx.next, %latch ]
+///   %cur  = load T, ...                          ; LHS
+///   %ext  = sext %idx
+///   %gep1 = getelementptr T1, ptr %base, %ext     ; %base loop invariant
+///   %gep2 = getelementptr T2, ptr %gep1, %off     ; %off loop invariant
+///   %min  = load T, ptr %gep2                     ; RHS, the reload
+///   %cmp  = {i,f}cmp {slt,ult,olt,ult} %cur, %min
+///   %idx.next = select i1 %cmp, %new, %idx        ; \p Select
+///
+/// On success the address chain of %min is re-materialized in the preheader
+/// for %init, every use of %min is replaced by
+///   %known_min = phi [ %hoisted_load, %preheader ], [ %current_min, %latch ]
+/// and %current_min = select i1 %cmp, %cur, %known_min is inserted before the
+/// block's terminator.
+///
+/// \returns true if the IR was changed.
+bool GVNPass::optimizeMinMaxFindingSelectPattern(SelectInst *Select) {
+  LLVM_DEBUG(
+      dbgs()
+      << "GVN: Analyzing select instruction for minimum finding pattern\n");
+  LLVM_DEBUG(dbgs() << "GVN: Select: " << *Select << "\n");
+  auto *Comparison = dyn_cast<CmpInst>(Select->getCondition());
+  if (!Comparison) {
+    LLVM_DEBUG(dbgs() << "GVN: Condition is not a comparison\n");
+    return false;
+  }
+
+  // Accept any strict less-than comparison: signed/unsigned integer or
+  // ordered/unordered floating point.
+  CmpInst::Predicate Pred = Comparison->getPredicate();
+  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT &&
+      Pred != CmpInst::FCMP_OLT && Pred != CmpInst::FCMP_ULT) {
+    LLVM_DEBUG(dbgs() << "GVN: Not a less-than comparison, predicate: " << Pred
+                      << "\n");
+    return false;
+  }
+
+  // Check that both operands are loads. Operand 1 is the reload of the
+  // current minimum that this transform removes.
+  Value *LHS = Comparison->getOperand(0);
+  Value *RHS = Comparison->getOperand(1);
+  auto *RHSLoad = dyn_cast<LoadInst>(RHS);
+  if (!isa<LoadInst>(LHS) || !RHSLoad) {
+    LLVM_DEBUG(dbgs() << "GVN: Not both operands are loads\n");
+    return false;
+  }
+  // LHS stays in the loop and feeds the new select, so it must not be the
+  // load that gets removed.
+  if (LHS == RHS) {
+    LLVM_DEBUG(dbgs() << "GVN: Comparison operands are the same load\n");
+    return false;
+  }
+
+  LLVM_DEBUG(dbgs() << "GVN: Found minimum finding pattern in Block: "
+                    << Select->getParent()->getName() << "\n");
+
+  // The new PHI gets exactly two incoming edges, from the preheader and from
+  // the select's own block, so that block must be both the loop header and
+  // the loop's only latch.
+  BasicBlock *BB = Select->getParent();
+  Loop *L = LI ? LI->getLoopFor(BB) : nullptr;
+  if (!L) {
+    LLVM_DEBUG(dbgs() << "GVN: Could not find loop\n");
+    return false;
+  }
+  BasicBlock *Preheader = L->getLoopPreheader();
+  if (!Preheader) {
+    LLVM_DEBUG(dbgs() << "GVN: Could not find loop preheader\n");
+    return false;
+  }
+  if (L->getHeader() != BB || L->getLoopLatch() != BB) {
+    LLVM_DEBUG(dbgs() << "GVN: Select is not in a single-block loop\n");
+    return false;
+  }
+
+  // The reload is about to be performed in the preheader on behalf of the
+  // first iteration. That is only equivalent when it is a simple
+  // (non-volatile, non-atomic) load in this block that executes whenever the
+  // block is entered, and the compare that uses it is in the same block.
+  if (Comparison->getParent() != BB || RHSLoad->getParent() != BB ||
+      !RHSLoad->isSimple() || ICF->isDominatedByICFIFromSameBlock(RHSLoad)) {
+    LLVM_DEBUG(dbgs() << "GVN: Reload of the minimum cannot be hoisted\n");
+    return false;
+  }
+
+  // Carrying the value across iterations removes one load per iteration,
+  // which is only valid if nothing in the loop can modify memory. The loop is
+  // this single block; conservatively reject any instruction that may write.
+  if (llvm::any_of(*BB, [](const Instruction &I) {
+        return I.mayWriteToMemory();
+      })) {
+    LLVM_DEBUG(dbgs() << "GVN: Loop may write to memory\n");
+    return false;
+  }
+
+  // Match the address chain of the reload, e.g.
+  //   %90 = sext i32 %.05.i to i64
+  //   %91 = getelementptr float, ptr %0, i64 %90
+  //   %92 = getelementptr i8, ptr %91, i64 -4
+  //   %93 = load float, ptr %92, align 4
+  Value *BasePtr = nullptr, *IndexVal = nullptr, *OffsetVal = nullptr;
+  if (!match(RHSLoad,
+             m_Load(m_GEP(m_GEP(m_Value(BasePtr), m_SExt(m_Value(IndexVal))),
+                          m_Value(OffsetVal))))) {
+    LLVM_DEBUG(dbgs() << "GVN: Could not find pattern: " << *RHS << "\n");
+    return false;
+  }
+  LLVM_DEBUG(dbgs() << "GVN: Found pattern: " << *RHS << "\n");
+
+  // The index must be a PHI of this header whose back-edge value is Select,
+  // and Select must keep the old index when the comparison fails, so the
+  // reloaded value is unchanged on those iterations.
+  auto *Phi = dyn_cast<PHINode>(IndexVal);
+  if (!Phi) {
+    LLVM_DEBUG(dbgs() << "GVN: IndexVal is not a PHI node\n");
+    return false;
+  }
+  if (Phi->getParent() != BB || Phi->getIncomingValueForBlock(BB) != Select ||
+      Select->getFalseValue() != Phi) {
+    LLVM_DEBUG(dbgs() << "GVN: Index PHI is not updated by the select\n");
+    return false;
+  }
+
+  // Everything in the address except the index must be usable from the
+  // preheader.
+  if (!L->isLoopInvariant(BasePtr) || !L->isLoopInvariant(OffsetVal)) {
+    LLVM_DEBUG(dbgs() << "GVN: Address operands are not loop invariant\n");
+    return false;
+  }
+
+  // FIXME: Nothing here proves that, when the comparison succeeds, the address
+  // derived from Select's true value equals the address loaded by LHS. That
+  // equality is what makes "known_min = LHS" correct after an update (it holds
+  // for the reduced rnflow loop this was written for), but it needs a real
+  // address-equivalence check (e.g. SCEV-based) before this is safe in
+  // general.
+
+  // Recover the matched instructions. The index is a PHI, so none of them can
+  // be a constant expression.
+  auto *OuterGEP = cast<GetElementPtrInst>(RHSLoad->getPointerOperand());
+  auto *InnerGEP = cast<GetElementPtrInst>(OuterGEP->getPointerOperand());
+  auto *IdxExt = cast<SExtInst>(InnerGEP->getOperand(1));
+
+  // Re-materialize the reload for the initial index in the preheader, reusing
+  // the types, GEP no-wrap flags and alignment of the in-loop chain instead of
+  // assuming float/i64.
+  IRBuilder<> Builder(Preheader->getTerminator());
+  Value *InitialMinIndex = Phi->getIncomingValueForBlock(Preheader);
+  Value *HoistedSExt =
+      Builder.CreateSExt(InitialMinIndex, IdxExt->getDestTy(), "hoist_sext");
+  Value *HoistedGEP1 =
+      Builder.CreateGEP(InnerGEP->getSourceElementType(), BasePtr, HoistedSExt,
+                        "hoist_gep1", InnerGEP->getNoWrapFlags());
+  Value *HoistedGEP2 =
+      Builder.CreateGEP(OuterGEP->getSourceElementType(), HoistedGEP1,
+                        OffsetVal, "hoist_gep2", OuterGEP->getNoWrapFlags());
+  LoadInst *NewLoad = Builder.CreateAlignedLoad(
+      RHSLoad->getType(), HoistedGEP2, RHSLoad->getAlign(), "hoisted_load");
+  NewLoad->setAAMetadata(RHSLoad->getAAMetadata());
+  // Keep MemorySSA (when in use) in sync with the new memory access.
+  if (MSSAU) {
+    auto *NewAccess = MSSAU->createMemoryAccessInBB(
+        NewLoad, /*Definition=*/nullptr, Preheader,
+        MemorySSA::BeforeTerminator);
+    MSSAU->insertUse(cast<MemoryUse>(NewAccess), /*RenameUses=*/true);
+  }
+
+  // Carry the current minimum in a header PHI, updated under the same
+  // condition that updates the index.
+  PHINode *KnownMinPhi =
+      PHINode::Create(RHSLoad->getType(), 2, "known_min", BB->begin());
+  IRBuilder<> SelectBuilder(BB->getTerminator());
+  Value *CurrentMinSelect =
+      SelectBuilder.CreateSelect(Comparison, LHS, KnownMinPhi, "current_min");
+  KnownMinPhi->addIncoming(NewLoad, Preheader);
+  KnownMinPhi->addIncoming(CurrentMinSelect, BB);
+
+  // Every use of the reload, not just the compare, now sees the carried
+  // value. Retire the reload through GVN's own bookkeeping rather than
+  // erasing it directly, so the value/leader tables, MemDep and MemorySSA do
+  // not keep dangling pointers.
+  RHSLoad->replaceAllUsesWith(KnownMinPhi);
+  if (uint32_t Num = VN.lookup(RHSLoad, /*Verify=*/false))
+    LeaderTable.erase(Num, RHSLoad, BB);
+  markInstructionForDeletion(RHSLoad);
+  return true;
+}
+
+
class llvm::gvn::GVNLegacyPass : public FunctionPass {
public:
static char ID; // Pass identification, replacement for typeid.
diff --git a/llvm/test/Transforms/GVN/PRE/rnflow-gvn-pre.ll b/llvm/test/Transforms/GVN/PRE/rnflow-gvn-pre.ll
new file mode 100644
index 0000000000000..6f17d4ab30240
--- /dev/null
+++ b/llvm/test/Transforms/GVN/PRE/rnflow-gvn-pre.ll
@@ -0,0 +1,59 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; Minimal test case reduced to the single-block loop .lr.ph.i: a min-index
+; search whose compare reloads the current minimum from memory on every iteration.
+; RUN: opt -passes=gvn -S < %s | FileCheck %s
+
+define void @test_lr_ph_i(ptr %0) {
+; CHECK-LABEL: define void @test_lr_ph_i(
+; CHECK-SAME: ptr [[TMP0:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*]]:
+; CHECK-NEXT: [[HOIST_GEP1:%.*]] = getelementptr float, ptr [[TMP0]], i64 1
+; CHECK-NEXT: [[HOIST_GEP2:%.*]] = getelementptr i8, ptr [[HOIST_GEP1]], i64 -4
+; CHECK-NEXT: [[HOISTED_LOAD:%.*]] = load float, ptr [[HOIST_GEP2]], align 4
+; CHECK-NEXT: br label %[[DOTLR_PH_I:.*]]
+; CHECK: [[_LR_PH_I:.*:]]
+; CHECK-NEXT: [[KNOWN_MIN:%.*]] = phi float [ [[HOISTED_LOAD]], %[[ENTRY]] ], [ [[CURRENT_MIN:%.*]], %[[DOTLR_PH_I]] ]
+; CHECK-NEXT: [[INDVARS_IV_I:%.*]] = phi i64 [ 1, %[[ENTRY]] ], [ [[INDVARS_IV_NEXT_I:%.*]], %[[DOTLR_PH_I]] ]
+; CHECK-NEXT: [[TMP1:%.*]] = phi i64 [ 0, %[[ENTRY]] ], [ [[TMP10:%.*]], %[[DOTLR_PH_I]] ]
+; CHECK-NEXT: [[DOT05_I:%.*]] = phi i32 [ 1, %[[ENTRY]] ], [ [[DOT1_I:%.*]], %[[DOTLR_PH_I]] ]
+; CHECK-NEXT: [[INDVARS_IV_NEXT_I]] = add nsw i64 [[INDVARS_IV_I]], -1
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[INDVARS_IV_I]]
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP2]], i64 -8
+; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[TMP3]], align 4
+; CHECK-NEXT: [[TMP5:%.*]] = sext i32 [[DOT05_I]] to i64
+; CHECK-NEXT: [[TMP6:%.*]] = getelementptr float, ptr [[TMP0]], i64 [[TMP5]]
+; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i64 -4
+; CHECK-NEXT: [[TMP8:%.*]] = fcmp contract olt float [[TMP4]], [[KNOWN_MIN]]
+; CHECK-NEXT: [[TMP9:%.*]] = trunc nsw i64 [[INDVARS_IV_NEXT_I]] to i32
+; CHECK-NEXT: [[DOT1_I]] = select i1 [[TMP8]], i32 [[TMP9]], i32 [[DOT05_I]]
+; CHECK-NEXT: [[TMP10]] = add nsw i64 [[TMP1]], -1
+; CHECK-NEXT: [[TMP11:%.*]] = icmp samesign ugt i64 [[TMP1]], 1
+; CHECK-NEXT: [[CURRENT_MIN]] = select i1 [[TMP8]], float [[TMP4]], float [[KNOWN_MIN]]
+; CHECK-NEXT: br i1 [[TMP11]], label %[[DOTLR_PH_I]], label %[[EXIT:.*]]
+; CHECK: [[EXIT]]:
+; CHECK-NEXT: ret void
+;
+entry:
+ br label %.lr.ph.i
+
+; Single-block min-index search: %.05.i is the index of the current minimum.
+; Each iteration compares the float at %0 + 4 * %indvars.iv.i - 8 with the
+; float at %0 + 4 * sext(%.05.i) - 4 (the current minimum). If it is smaller,
+; the index moves to trunc(%indvars.iv.next.i) = %indvars.iv.i - 1.
+; The new index addresses %0 + 4 * (%indvars.iv.i - 1) - 4, which is exactly
+; the element just compared. GVN can therefore carry the minimum in a PHI
+; (known_min), seeded by a load hoisted into %entry, instead of reloading it
+; through %93.
+.lr.ph.i: ; preds = %.lr.ph.i, %entry
+ %indvars.iv.i = phi i64 [ 1, %entry ], [ %indvars.iv.next.i, %.lr.ph.i ]
+ %86 = phi i64 [ 0, %entry ], [ %96, %.lr.ph.i ]
+ %.05.i = phi i32 [ 1, %entry ], [ %.1.i, %.lr.ph.i ]
+ %indvars.iv.next.i = add nsw i64 %indvars.iv.i, -1
+ %87 = getelementptr float, ptr %0, i64 %indvars.iv.i
+ %88 = getelementptr i8, ptr %87, i64 -8 ; first load: %0 + 4 * %indvars.iv.i - 8 (%0 - 4 on the first iteration)
+ %89 = load float, ptr %88, align 4
+ %90 = sext i32 %.05.i to i64
+ %91 = getelementptr float, ptr %0, i64 %90 ; %0 + 4 * sext(%.05.i) (%0 + 4 on the first iteration)
+ %92 = getelementptr i8, ptr %91, i64 -4 ; second load: %0 + 4 * sext(%.05.i) - 4 (%0 on the first iteration)
+ %93 = load float, ptr %92, align 4
+ %94 = fcmp contract olt float %89, %93
+ %95 = trunc nsw i64 %indvars.iv.next.i to i32
+ %.1.i = select i1 %94, i32 %95, i32 %.05.i
+ %96 = add nsw i64 %86, -1
+ %97 = icmp samesign ugt i64 %86, 1
+ br i1 %97, label %.lr.ph.i, label %exit
+
+exit:
+ ret void
+}
More information about the llvm-commits mailing list