[llvm] [llvm][instcombine] Add Missed Optimization for Folding Min Max intrinsic into PHI instruction (PR #84619)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 11 11:08:44 PDT 2024
https://github.com/PeterChou1 updated https://github.com/llvm/llvm-project/pull/84619
>From 93d329318edce86e5fcb81532c4a449984aa4e49 Mon Sep 17 00:00:00 2001
From: PeterChou1 <peter.chou at mail.utoronto.ca>
Date: Sat, 9 Mar 2024 01:44:24 -0500
Subject: [PATCH 1/4] [llvm][instcombine] adds missed fold optimization for min
max intrinsics in phi instcombine pass
---
.../InstCombine/InstCombineInternal.h | 3 +
.../Transforms/InstCombine/InstCombinePHI.cpp | 90 +++++++++++++++++++
2 files changed, 93 insertions(+)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 57148d719d9b61..c9221dae533346 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -624,6 +624,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldPHIArgLoadIntoPHI(PHINode &PN);
Instruction *foldPHIArgZextsIntoPHI(PHINode &PN);
Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN);
+ Instruction *foldPHIWithMinMax(PHINode &PN);
+ Instruction *foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
+ ICmpInst::Predicate Pred);
/// If an integer typed PHI has only one use which is an IntToPtr operation,
/// replace the PHI with an existing pointer typed PHI if it exists. Otherwise
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 46bca4b722a03a..7c5696ccd2efde 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -611,6 +611,93 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) {
return NewGEP;
}
+/// helper function for foldPHIWithMinMax
+Instruction *
+InstCombinerImpl::foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
+ ICmpInst::Predicate Pred) {
+
+ auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> {
+ if (!Val)
+ return std::nullopt;
+ if (match(Val, m_One()))
+ return true;
+ if (match(Val, m_Zero()))
+ return false;
+ return std::nullopt;
+ };
+
+ ICmpInst::Predicate SwappedPred =
+ ICmpInst::getNonStrictPredicate(ICmpInst::getSwappedPredicate(Pred));
+ for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) {
+ if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(PN.getIncomingValue(OpNum))) {
+ if (Pred != MinMax->getPredicate())
+ continue;
+
+ Value *X = MinMax->getLHS();
+ Value *Y = MinMax->getRHS();
+
+ SimplifyQuery Q = SQ.getWithInstruction(I);
+
+ auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, X, Z, Q));
+ auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, Y, Z, Q));
+
+ if (!CmpXZ.has_value() && !CmpYZ.has_value())
+ continue;
+ if (CmpXZ.has_value() && CmpYZ.has_value())
+ continue;
+
+ if (!CmpXZ.has_value()) {
+ std::swap(X, Y);
+ std::swap(CmpXZ, CmpYZ);
+ }
+
+ switch (Pred) {
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_UGT:
+ // if X > Z
+ // %min = llvm.min ( X, Y )
+ // %phi = phi %min ... => %phi = phi X ..
+ // %cmp = icmp lt %phi, Z %cmp = icmp %phi, Z
+ // if X < Z
+ // %max = llvm.max ( X, Y )
+ // %phi = phi %max ... => %phi = phi X ..
+ // %cmp = icmp gt %phi, Z %cmp = icmp %phi, Z
+ if (CmpXZ.value()) {
+ if (MinMax->hasOneUse()) {
+ MinMax->eraseFromParent();
+ }
+ PN.setIncomingValue(OpNum, Y);
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ return nullptr;
+}
+
+/// Fold max min intrinsic into PHI instruction
+Instruction *InstCombinerImpl::foldPHIWithMinMax(PHINode &PN) {
+
+ if (!PN.hasOneUse())
+ return nullptr;
+
+ // The PHI instruction must only be used by a ICmp Instruction
+ if (ICmpInst *ICmp = dyn_cast<ICmpInst>(PN.getUniqueUndroppableUser())) {
+ Value *Op0 = ICmp->getOperand(0), *Op1 = ICmp->getOperand(1);
+ // case 1: icmp <op> %phi, %other
+ if (isa<PHINode>(Op0))
+ return foldPHIWithMinMaxHelper(PN, ICmp, Op1, ICmp->getPredicate());
+ // case 2: icmp <op> %intrinsic, %phi
+ else if (isa<PHINode>(Op1))
+ return foldPHIWithMinMaxHelper(PN, ICmp, Op0, ICmp->getPredicate());
+ }
+ return nullptr;
+}
+
/// Return true if we know that it is safe to sink the load out of the block
/// that defines it. This means that it must be obvious the value of the load is
/// not changed from the point of the load to the end of the block it is in.
@@ -1651,5 +1738,8 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
if (Value *Res = foldDependentIVs(PN, Builder))
return replaceInstUsesWith(PN, Res);
+ if (Instruction *Result = foldPHIWithMinMax(PN))
+ return Result;
+
return nullptr;
}
>From d2798eebe9ae1f2899ac863fe9626a9aa05f74bb Mon Sep 17 00:00:00 2001
From: PeterChou1 <peter.chou at mail.utoronto.ca>
Date: Sat, 9 Mar 2024 04:35:01 -0500
Subject: [PATCH 2/4] [llvm][instcombine] added test for folding min/max
intrinsics into phi instructions
---
.../Transforms/InstCombine/InstCombinePHI.cpp | 13 +-
.../Transforms/InstCombine/fold-phi-minmax.ll | 314 ++++++++++++++++++
2 files changed, 321 insertions(+), 6 deletions(-)
create mode 100644 llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 7c5696ccd2efde..71f8cbeb9a9dbd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -646,6 +646,7 @@ InstCombinerImpl::foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
if (CmpXZ.has_value() && CmpYZ.has_value())
continue;
+ // swap XZ with YZ so XZ always has value
if (!CmpXZ.has_value()) {
std::swap(X, Y);
std::swap(CmpXZ, CmpYZ);
@@ -656,13 +657,13 @@ InstCombinerImpl::foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_UGT:
- // if X > Z
+ // if X >= Z
// %min = llvm.min ( X, Y )
- // %phi = phi %min ... => %phi = phi X ..
+ // %phi = phi %min ... => %phi = phi Y ..
// %cmp = icmp lt %phi, Z %cmp = icmp %phi, Z
- // if X < Z
+ // if X <= Z
// %max = llvm.max ( X, Y )
- // %phi = phi %max ... => %phi = phi X ..
+ // %phi = phi %max ... => %phi = phi Y ..
// %cmp = icmp gt %phi, Z %cmp = icmp %phi, Z
if (CmpXZ.value()) {
if (MinMax->hasOneUse()) {
@@ -688,12 +689,12 @@ Instruction *InstCombinerImpl::foldPHIWithMinMax(PHINode &PN) {
// The PHI instruction must only be used by a ICmp Instruction
if (ICmpInst *ICmp = dyn_cast<ICmpInst>(PN.getUniqueUndroppableUser())) {
Value *Op0 = ICmp->getOperand(0), *Op1 = ICmp->getOperand(1);
- // case 1: icmp <op> %phi, %other
+ // case 1: icmp <op> %phi, %intrinsic
if (isa<PHINode>(Op0))
return foldPHIWithMinMaxHelper(PN, ICmp, Op1, ICmp->getPredicate());
// case 2: icmp <op> %intrinsic, %phi
else if (isa<PHINode>(Op1))
- return foldPHIWithMinMaxHelper(PN, ICmp, Op0, ICmp->getPredicate());
+ return foldPHIWithMinMaxHelper(PN, ICmp, Op0, ICmp->getSwappedPredicate());
}
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll b/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
new file mode 100644
index 00000000000000..55b464cdc58248
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
@@ -0,0 +1,314 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+
+; test phi combine less than (equal)
+define i1 @src0(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src0(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 6, i32 %a)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ult i32 %ind, 6
+ ret i1 %cmp
+}
+
+; test phi combine less than (swapped) good
+define i1 @src1(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src1(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umax.i32(i32 6, i32 %a)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ult i32 6, %ind
+ ret i1 %cmp
+}
+
+; test phi combine less than (swapped) bad
+define i1 @src2(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src2(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 6)
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umax.i32(i32 6, i32 %a)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ugt i32 6, %ind
+ ret i1 %cmp
+}
+
+
+; test phi combine less than (reversed)
+define i1 @src3(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src3(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 %a, i32 6)
+ br label %loop
+
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ult i32 %ind, 6
+ ret i1 %cmp
+}
+
+; test phi combine less than (over)
+define i1 @src4(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src4(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 %a, i32 7)
+ br label %loop
+
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ult i32 %ind, 6
+ ret i1 %cmp
+}
+
+; test phi combine less than (under)
+define i1 @src5(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src5(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 5)
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 %a, i32 5)
+ br label %loop
+
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ult i32 %ind, 6
+ ret i1 %cmp
+}
+
+; test phi combine greater than (equal)
+define i1 @src6(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src6(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umax.i32(i32 %a, i32 6)
+ br label %loop
+
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ugt i32 %ind, 6
+ ret i1 %cmp
+}
+
+; test phi combine greater than (over)
+define i1 @src7(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src7(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 7)
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umax.i32(i32 %a, i32 7)
+ br label %loop
+
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ugt i32 %ind, 6
+ ret i1 %cmp
+}
+
+; test phi combine greater than (under)
+define i1 @src8(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src8(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umax.i32(i32 %a, i32 5)
+ br label %loop
+
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ugt i32 %ind, 6
+ ret i1 %cmp
+}
+
+; test phi combine greater than (swapped-equal) good
+define i1 @src9(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src9(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 %a, i32 6)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ugt i32 6, %ind
+ ret i1 %cmp
+}
+
+; test phi combine greater than (swapped-over) good
+define i1 @src11(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src11(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 %a, i32 7)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ugt i32 6, %ind
+ ret i1 %cmp
+}
+
+; test phi combine greater than (swapped-under) bad
+define i1 @src12(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src12(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 5)
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 %a, i32 5)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ugt i32 6, %ind
+ ret i1 %cmp
+}
+
+; test phi combine less than (swapped-equal) bad
+define i1 @src13(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src13(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 6)
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 %a, i32 6)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ult i32 6, %ind
+ ret i1 %cmp
+}
>From 6de82a748a290c6e879a4922fa7185c98ce829fa Mon Sep 17 00:00:00 2001
From: PeterChou1 <peter.chou at mail.utoronto.ca>
Date: Sat, 9 Mar 2024 04:47:01 -0500
Subject: [PATCH 3/4] [llvm][instcombine] format file for InstCombinePHI
---
llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 71f8cbeb9a9dbd..9cae84741722f6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -694,7 +694,8 @@ Instruction *InstCombinerImpl::foldPHIWithMinMax(PHINode &PN) {
return foldPHIWithMinMaxHelper(PN, ICmp, Op1, ICmp->getPredicate());
// case 2: icmp <op> %intrinsic, %phi
else if (isa<PHINode>(Op1))
- return foldPHIWithMinMaxHelper(PN, ICmp, Op0, ICmp->getSwappedPredicate());
+ return foldPHIWithMinMaxHelper(PN, ICmp, Op0,
+ ICmp->getSwappedPredicate());
}
return nullptr;
}
>From 4156401962c0a45ff63aa98f87aaab7eabbb8c2f Mon Sep 17 00:00:00 2001
From: PeterChou1 <peter.chou at mail.utoronto.ca>
Date: Mon, 11 Mar 2024 14:03:48 -0400
Subject: [PATCH 4/4] [llvm][instcombine] refactor optimization for folding PHI
ICmp to InstCombineCompares
---
.../InstCombine/InstCombineCompares.cpp | 87 ++++++++++++++++++
.../InstCombine/InstCombineInternal.h | 7 +-
.../Transforms/InstCombine/InstCombinePHI.cpp | 92 -------------------
.../Transforms/InstCombine/fold-phi-minmax.ll | 73 ++++++++++++---
4 files changed, 150 insertions(+), 109 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fc2688f425bb8f..34be7356a0a886 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6908,6 +6908,90 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
return nullptr;
}
+/// Helper for foldICmpPHIWithMinMax; replaces min/max operands of the PHI.
+Instruction *InstCombinerImpl::foldICmpPHIWithMinMaxHelper(
+ PHINode &PN, Instruction &I, Value *Z, ICmpInst::Predicate Pred) {
+
+ if (!PN.hasOneUse())
+ return nullptr;
+
+ bool Changed = false;
+ auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> {
+ if (!Val)
+ return std::nullopt;
+ if (match(Val, m_One()))
+ return true;
+ if (match(Val, m_Zero()))
+ return false;
+ return std::nullopt;
+ };
+
+ ICmpInst::Predicate SwappedPred =
+ ICmpInst::getNonStrictPredicate(ICmpInst::getSwappedPredicate(Pred));
+ for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) {
+ if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(PN.getIncomingValue(OpNum))) {
+ if (Pred != MinMax->getPredicate())
+ continue;
+
+ Value *X = MinMax->getLHS();
+ Value *Y = MinMax->getRHS();
+
+ SimplifyQuery Q =
+ SQ.getWithInstruction(PN.getIncomingBlock(OpNum)->getTerminator());
+
+ auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, X, Z, Q));
+ auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, Y, Z, Q));
+
+ if (!CmpXZ.has_value() && !CmpYZ.has_value())
+ continue;
+
+ if (CmpXZ.has_value() && CmpYZ.has_value()) {
+ auto CmpXY = IsCondKnownTrue(simplifyICmpInst(SwappedPred, X, Y, Q));
+ if (!CmpXY.has_value())
+ continue;
+ // Both are known: order X and Y so that Y is the value the intrinsic yields.
+ if (!CmpXY.value()) {
+ std::swap(X, Y);
+ std::swap(CmpXZ, CmpYZ);
+ }
+ }
+ // Only CmpYZ is known: swap so that X/CmpXZ refer to the known comparison.
+ else if (!CmpXZ.has_value()) {
+ std::swap(X, Y);
+ std::swap(CmpXZ, CmpYZ);
+ }
+ // If X >= Z is known:
+ //   %min = llvm.min(X, Y)
+ //   %phi = phi [%min, ...]    =>    %phi = phi [Y, ...]
+ //   %cmp = icmp lt %phi, Z          %cmp = icmp lt %phi, Z
+ // If X <= Z is known:
+ //   %max = llvm.max(X, Y)
+ //   %phi = phi [%max, ...]    =>    %phi = phi [Y, ...]
+ //   %cmp = icmp gt %phi, Z          %cmp = icmp gt %phi, Z
+ if (CmpXZ.value()) {
+ Changed = true;
+ replaceOperand(PN, OpNum, Y);
+ }
+ }
+ }
+ return Changed ? &I : nullptr;
+}
+
+/// Folds min/max intrinsics that feed a PHI into the PHI itself, based on the
+/// ICmp instruction that uses the PHI.
+Instruction *InstCombinerImpl::foldICmpPHIWithMinMax(ICmpInst &Cmp) {
+ Value *Op0 = Cmp.getOperand(0), *Op1 = Cmp.getOperand(1);
+ // case 1: icmp <op> %phi, %other
+ if (PHINode *PNOp0 = dyn_cast<PHINode>(Op0))
+ return foldICmpPHIWithMinMaxHelper(*PNOp0, Cmp, Op1, Cmp.getPredicate());
+ // case 2: icmp <op> %other, %phi
+ else if (PHINode *PNOp1 = dyn_cast<PHINode>(Op1))
+ return foldICmpPHIWithMinMaxHelper(*PNOp1, Cmp, Op0,
+ Cmp.getSwappedPredicate());
+
+ return nullptr;
+}
+
// This helper will be called with icmp operands in both orders.
Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
Value *Op0, Value *Op1,
@@ -7287,6 +7371,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldReductionIdiom(I, Builder, DL))
return Res;
+ if (Instruction *Res = foldICmpPHIWithMinMax(I))
+ return Res;
+
return Changed ? &I : nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index c9221dae533346..f5dc295aafd121 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -624,9 +624,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldPHIArgLoadIntoPHI(PHINode &PN);
Instruction *foldPHIArgZextsIntoPHI(PHINode &PN);
Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN);
- Instruction *foldPHIWithMinMax(PHINode &PN);
- Instruction *foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
- ICmpInst::Predicate Pred);
/// If an integer typed PHI has only one use which is an IntToPtr operation,
/// replace the PHI with an existing pointer typed PHI if it exists. Otherwise
@@ -724,6 +721,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldICmpCommutative(ICmpInst::Predicate Pred, Value *Op0,
Value *Op1, ICmpInst &CxtI);
+ Instruction *foldICmpPHIWithMinMaxHelper(PHINode &PN, Instruction &I,
+ Value *Z, ICmpInst::Predicate Pred);
+ Instruction *foldICmpPHIWithMinMax(ICmpInst &Cmp);
+
// Helpers of visitSelectInst().
Instruction *foldSelectOfBools(SelectInst &SI);
Instruction *foldSelectExtConst(SelectInst &Sel);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 9cae84741722f6..46bca4b722a03a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -611,95 +611,6 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) {
return NewGEP;
}
-/// helper function for foldPHIWithMinMax
-Instruction *
-InstCombinerImpl::foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
- ICmpInst::Predicate Pred) {
-
- auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> {
- if (!Val)
- return std::nullopt;
- if (match(Val, m_One()))
- return true;
- if (match(Val, m_Zero()))
- return false;
- return std::nullopt;
- };
-
- ICmpInst::Predicate SwappedPred =
- ICmpInst::getNonStrictPredicate(ICmpInst::getSwappedPredicate(Pred));
- for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) {
- if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(PN.getIncomingValue(OpNum))) {
- if (Pred != MinMax->getPredicate())
- continue;
-
- Value *X = MinMax->getLHS();
- Value *Y = MinMax->getRHS();
-
- SimplifyQuery Q = SQ.getWithInstruction(I);
-
- auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, X, Z, Q));
- auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, Y, Z, Q));
-
- if (!CmpXZ.has_value() && !CmpYZ.has_value())
- continue;
- if (CmpXZ.has_value() && CmpYZ.has_value())
- continue;
-
- // swap XZ with YZ so XZ always has value
- if (!CmpXZ.has_value()) {
- std::swap(X, Y);
- std::swap(CmpXZ, CmpYZ);
- }
-
- switch (Pred) {
- case ICmpInst::ICMP_SLT:
- case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_SGT:
- case ICmpInst::ICMP_UGT:
- // if X >= Z
- // %min = llvm.min ( X, Y )
- // %phi = phi %min ... => %phi = phi Y ..
- // %cmp = icmp lt %phi, Z %cmp = icmp %phi, Z
- // if X <= Z
- // %max = llvm.max ( X, Y )
- // %phi = phi %max ... => %phi = phi Y ..
- // %cmp = icmp gt %phi, Z %cmp = icmp %phi, Z
- if (CmpXZ.value()) {
- if (MinMax->hasOneUse()) {
- MinMax->eraseFromParent();
- }
- PN.setIncomingValue(OpNum, Y);
- }
- break;
- default:
- break;
- }
- }
- }
- return nullptr;
-}
-
-/// Fold max min intrinsic into PHI instruction
-Instruction *InstCombinerImpl::foldPHIWithMinMax(PHINode &PN) {
-
- if (!PN.hasOneUse())
- return nullptr;
-
- // The PHI instruction must only be used by a ICmp Instruction
- if (ICmpInst *ICmp = dyn_cast<ICmpInst>(PN.getUniqueUndroppableUser())) {
- Value *Op0 = ICmp->getOperand(0), *Op1 = ICmp->getOperand(1);
- // case 1: icmp <op> %phi, %intrinsic
- if (isa<PHINode>(Op0))
- return foldPHIWithMinMaxHelper(PN, ICmp, Op1, ICmp->getPredicate());
- // case 2: icmp <op> %intrinsic, %phi
- else if (isa<PHINode>(Op1))
- return foldPHIWithMinMaxHelper(PN, ICmp, Op0,
- ICmp->getSwappedPredicate());
- }
- return nullptr;
-}
-
/// Return true if we know that it is safe to sink the load out of the block
/// that defines it. This means that it must be obvious the value of the load is
/// not changed from the point of the load to the end of the block it is in.
@@ -1740,8 +1651,5 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
if (Value *Res = foldDependentIVs(PN, Builder))
return replaceInstUsesWith(PN, Res);
- if (Instruction *Result = foldPHIWithMinMax(PN))
- return Result;
-
return nullptr;
}
diff --git a/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll b/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
index 55b464cdc58248..c7cf0c9fc4f18f 100644
--- a/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
@@ -1,9 +1,8 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
-
; test phi combine less than (equal)
-define i1 @src0(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src0(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src0(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -26,7 +25,7 @@ loop:
}
; test phi combine less than (swapped) good
-define i1 @src1(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src1(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src1(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -49,7 +48,7 @@ loop:
}
; test phi combine less than (swapped) bad
-define i1 @src2(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src2(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src2(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -74,7 +73,7 @@ loop:
; test phi combine less than (reversed)
-define i1 @src3(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src3(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src3(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -98,7 +97,7 @@ loop:
}
; test phi combine less than (over)
-define i1 @src4(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src4(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src4(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -122,7 +121,7 @@ loop:
}
; test phi combine less than (under)
-define i1 @src5(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src5(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src5(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -147,7 +146,7 @@ loop:
}
; test phi combine greater than (equal)
-define i1 @src6(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src6(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src6(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -171,7 +170,7 @@ loop:
}
; test phi combine greater than (over)
-define i1 @src7(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src7(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src7(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -196,7 +195,7 @@ loop:
}
; test phi combine greater than (under)
-define i1 @src8(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src8(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src8(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -220,7 +219,7 @@ loop:
}
; test phi combine greater than (swapped-equal) good
-define i1 @src9(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src9(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src9(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -243,7 +242,7 @@ loop:
}
; test phi combine greater than (swapped-over) good
-define i1 @src11(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src11(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src11(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -266,7 +265,7 @@ loop:
}
; test phi combine greater than (swapped-under) bad
-define i1 @src12(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src12(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src12(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -290,7 +289,7 @@ loop:
}
; test phi combine less than (swapped-equal) bad
-define i1 @src13(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src13(i32 %a, i32 %b, i1 %c) {
; CHECK-LABEL: @src13(
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -312,3 +311,49 @@ loop:
%cmp = icmp ult i32 6, %ind
ret i1 %cmp
}
+
+; test phi combine (min) both value defined
+define i1 @src14(i32 %a, i32 %b, i1 %c) {
+; CHECK-LABEL: @src14(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ 6, [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umin.i32(i32 7, i32 6)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ult i32 6, %ind
+ ret i1 %cmp
+}
+
+; test phi combine (max) both value defined
+define i1 @src15(i32 %a, i32 %b, i1 %c) {
+; CHECK-LABEL: @src15(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK: then:
+; CHECK-NEXT: br label [[LOOP]]
+; CHECK: loop:
+; CHECK-NEXT: [[IND:%.*]] = phi i32 [ 7, [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+entry:
+ br i1 %c, label %then, label %loop
+then:
+ %min = call i32 @llvm.umax.i32(i32 7, i32 6)
+ br label %loop
+loop:
+ %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+ %cmp = icmp ugt i32 6, %ind
+ ret i1 %cmp
+}
More information about the llvm-commits
mailing list