[llvm] [llvm][instcombine] Add Missed Optimization for Folding Min Max intrinsic into PHI instruction (PR #84619)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 11 11:08:44 PDT 2024


https://github.com/PeterChou1 updated https://github.com/llvm/llvm-project/pull/84619

>From 93d329318edce86e5fcb81532c4a449984aa4e49 Mon Sep 17 00:00:00 2001
From: PeterChou1 <peter.chou at mail.utoronto.ca>
Date: Sat, 9 Mar 2024 01:44:24 -0500
Subject: [PATCH 1/4] [llvm][instcombine] adds missed fold optimization for min
 max intrinsics in phi instcombine pass

---
 .../InstCombine/InstCombineInternal.h         |  3 +
 .../Transforms/InstCombine/InstCombinePHI.cpp | 90 +++++++++++++++++++
 2 files changed, 93 insertions(+)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 57148d719d9b61..c9221dae533346 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -624,6 +624,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldPHIArgLoadIntoPHI(PHINode &PN);
   Instruction *foldPHIArgZextsIntoPHI(PHINode &PN);
   Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN);
+  Instruction *foldPHIWithMinMax(PHINode &PN);
+  Instruction *foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
+                                       ICmpInst::Predicate Pred);
 
   /// If an integer typed PHI has only one use which is an IntToPtr operation,
   /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 46bca4b722a03a..7c5696ccd2efde 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -611,6 +611,93 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) {
   return NewGEP;
 }
 
+/// helper function for foldPHIWithMinMax
+Instruction *
+InstCombinerImpl::foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
+                                          ICmpInst::Predicate Pred) {
+
+  auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> {
+    if (!Val)
+      return std::nullopt;
+    if (match(Val, m_One()))
+      return true;
+    if (match(Val, m_Zero()))
+      return false;
+    return std::nullopt;
+  };
+
+  ICmpInst::Predicate SwappedPred =
+      ICmpInst::getNonStrictPredicate(ICmpInst::getSwappedPredicate(Pred));
+  for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) {
+    if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(PN.getIncomingValue(OpNum))) {
+      if (Pred != MinMax->getPredicate())
+        continue;
+
+      Value *X = MinMax->getLHS();
+      Value *Y = MinMax->getRHS();
+
+      SimplifyQuery Q = SQ.getWithInstruction(I);
+
+      auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, X, Z, Q));
+      auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, Y, Z, Q));
+
+      if (!CmpXZ.has_value() && !CmpYZ.has_value())
+        continue;
+      if (CmpXZ.has_value() && CmpYZ.has_value())
+        continue;
+
+      if (!CmpXZ.has_value()) {
+        std::swap(X, Y);
+        std::swap(CmpXZ, CmpYZ);
+      }
+
+      switch (Pred) {
+      case ICmpInst::ICMP_SLT:
+      case ICmpInst::ICMP_ULT:
+      case ICmpInst::ICMP_SGT:
+      case ICmpInst::ICMP_UGT:
+        // if X > Z
+        // %min = llvm.min ( X, Y )
+        // %phi = phi %min ...         =>  %phi = phi X ..
+        // %cmp = icmp lt %phi, Z          %cmp = icmp %phi, Z
+        // if X < Z
+        // %max = llvm.max ( X, Y )
+        // %phi = phi %max ...         =>  %phi = phi X ..
+        // %cmp = icmp gt %phi, Z          %cmp = icmp %phi, Z
+        if (CmpXZ.value()) {
+          if (MinMax->hasOneUse()) {
+            MinMax->eraseFromParent();
+          }
+          PN.setIncomingValue(OpNum, Y);
+        }
+        break;
+      default:
+        break;
+      }
+    }
+  }
+  return nullptr;
+}
+
+/// Fold max min intrinsic into PHI instruction
+Instruction *InstCombinerImpl::foldPHIWithMinMax(PHINode &PN) {
+
+  if (!PN.hasOneUse())
+    return nullptr;
+
+  // The PHI instruction must only be used by a ICmp Instruction
+  if (ICmpInst *ICmp = dyn_cast<ICmpInst>(PN.getUniqueUndroppableUser())) {
+    Value *Op0 = ICmp->getOperand(0), *Op1 = ICmp->getOperand(1);
+    // case 1: icmp <op> %phi, %other
+    if (isa<PHINode>(Op0))
+      return foldPHIWithMinMaxHelper(PN, ICmp, Op1, ICmp->getPredicate());
+    // case 2: icmp <op> %intrinsic, %phi
+    else if (isa<PHINode>(Op1))
+      return foldPHIWithMinMaxHelper(PN, ICmp, Op0, ICmp->getPredicate());
+  }
+  return nullptr;
+}
+
 /// Return true if we know that it is safe to sink the load out of the block
 /// that defines it. This means that it must be obvious the value of the load is
 /// not changed from the point of the load to the end of the block it is in.
@@ -1651,5 +1738,8 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
   if (Value *Res = foldDependentIVs(PN, Builder))
     return replaceInstUsesWith(PN, Res);
 
+  if (Instruction *Result = foldPHIWithMinMax(PN))
+    return Result;
+
   return nullptr;
 }

>From d2798eebe9ae1f2899ac863fe9626a9aa05f74bb Mon Sep 17 00:00:00 2001
From: PeterChou1 <peter.chou at mail.utoronto.ca>
Date: Sat, 9 Mar 2024 04:35:01 -0500
Subject: [PATCH 2/4] [llvm][instcombine] added test for folding min/max
 intrinsics into phi instructions

---
 .../Transforms/InstCombine/InstCombinePHI.cpp |  13 +-
 .../Transforms/InstCombine/fold-phi-minmax.ll | 314 ++++++++++++++++++
 2 files changed, 321 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/Transforms/InstCombine/fold-phi-minmax.ll

diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 7c5696ccd2efde..71f8cbeb9a9dbd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -646,6 +646,7 @@ InstCombinerImpl::foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
       if (CmpXZ.has_value() && CmpYZ.has_value())
         continue;
 
+      // swap XZ with YZ so XZ always has value
       if (!CmpXZ.has_value()) {
         std::swap(X, Y);
         std::swap(CmpXZ, CmpYZ);
@@ -656,13 +657,13 @@ InstCombinerImpl::foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
       case ICmpInst::ICMP_ULT:
       case ICmpInst::ICMP_SGT:
       case ICmpInst::ICMP_UGT:
-        // if X > Z
+        // if X >= Z
         // %min = llvm.min ( X, Y )
-        // %phi = phi %min ...         =>  %phi = phi X ..
+        // %phi = phi %min ...         =>  %phi = phi Y ..
         // %cmp = icmp lt %phi, Z          %cmp = icmp %phi, Z
-        // if X < Z
+        // if X <= Z
         // %max = llvm.max ( X, Y )
-        // %phi = phi %max ...         =>  %phi = phi X ..
+        // %phi = phi %max ...         =>  %phi = phi Y ..
         // %cmp = icmp gt %phi, Z          %cmp = icmp %phi, Z
         if (CmpXZ.value()) {
           if (MinMax->hasOneUse()) {
@@ -688,12 +689,12 @@ Instruction *InstCombinerImpl::foldPHIWithMinMax(PHINode &PN) {
   // The PHI instruction must only be used by a ICmp Instruction
   if (ICmpInst *ICmp = dyn_cast<ICmpInst>(PN.getUniqueUndroppableUser())) {
     Value *Op0 = ICmp->getOperand(0), *Op1 = ICmp->getOperand(1);
-    // case 1: icmp <op> %phi, %other
+    // case 1: icmp <op> %phi, %intrinsic
     if (isa<PHINode>(Op0))
       return foldPHIWithMinMaxHelper(PN, ICmp, Op1, ICmp->getPredicate());
     // case 2: icmp <op> %intrinsic, %phi
     else if (isa<PHINode>(Op1))
-      return foldPHIWithMinMaxHelper(PN, ICmp, Op0, ICmp->getPredicate());
+      return foldPHIWithMinMaxHelper(PN, ICmp, Op0, ICmp->getSwappedPredicate());
   }
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll b/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
new file mode 100644
index 00000000000000..55b464cdc58248
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
@@ -0,0 +1,314 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+
+; test phi combine less than (equal)
+define i1 @src0(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src0(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 6, i32 %a)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ult i32 %ind, 6
+  ret i1 %cmp
+}
+
+; test phi combine less than (swapped) good
+define i1 @src1(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umax.i32(i32 6, i32 %a)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ult i32 6, %ind
+  ret i1 %cmp
+}
+
+; test phi combine less than (swapped) bad
+define i1 @src2(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 6)
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umax.i32(i32 6, i32 %a)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ugt i32 6, %ind
+  ret i1 %cmp
+}
+
+
+; test phi combine less than (reversed)
+define i1 @src3(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src3(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 %a, i32 6)
+  br label %loop
+
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ult i32 %ind, 6
+  ret i1 %cmp
+}
+
+; test phi combine less than (over)
+define i1 @src4(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src4(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 %a, i32 7)
+  br label %loop
+
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ult i32 %ind, 6
+  ret i1 %cmp
+}
+
+; test phi combine less than (under)
+define i1 @src5(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src5(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 5)
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 %a, i32 5)
+  br label %loop
+
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ult i32 %ind, 6
+  ret i1 %cmp
+}
+
+; test phi combine greater than (equal)
+define i1 @src6(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src6(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umax.i32(i32 %a, i32 6)
+  br label %loop
+
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ugt i32 %ind, 6
+  ret i1 %cmp
+}
+
+; test phi combine greater than (over)
+define i1 @src7(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src7(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umax.i32(i32 [[A:%.*]], i32 7)
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umax.i32(i32 %a, i32 7)
+  br label %loop
+
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ugt i32 %ind, 6
+  ret i1 %cmp
+}
+
+; test phi combine greater than (under)
+define i1 @src8(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src8(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umax.i32(i32 %a, i32 5)
+  br label %loop
+
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ugt i32 %ind, 6
+  ret i1 %cmp
+}
+
+; test phi combine greater than (swapped-equal) good
+define i1 @src9(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src9(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 %a, i32 6)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ugt i32 6, %ind
+  ret i1 %cmp
+}
+
+; test phi combine greater than (swapped-over) good
+define i1 @src11(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src11(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[A:%.*]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 %a, i32 7)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ugt i32 6, %ind
+  ret i1 %cmp
+}
+
+; test phi combine greater than (swapped-under) bad
+define i1 @src12(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src12(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 5)
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 %a, i32 5)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ugt i32 6, %ind
+  ret i1 %cmp
+}
+
+; test phi combine less than (swapped-equal) bad
+define i1 @src13(i32 %a, i32 %b, i1 %c) #0 {
+; CHECK-LABEL: @src13(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 6)
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ [[MIN]], [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 %a, i32 6)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ult i32 6, %ind
+  ret i1 %cmp
+}

>From 6de82a748a290c6e879a4922fa7185c98ce829fa Mon Sep 17 00:00:00 2001
From: PeterChou1 <peter.chou at mail.utoronto.ca>
Date: Sat, 9 Mar 2024 04:47:01 -0500
Subject: [PATCH 3/4] [llvm][instcombine] format file for InstCombinePHI

---
 llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 71f8cbeb9a9dbd..9cae84741722f6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -694,7 +694,8 @@ Instruction *InstCombinerImpl::foldPHIWithMinMax(PHINode &PN) {
       return foldPHIWithMinMaxHelper(PN, ICmp, Op1, ICmp->getPredicate());
     // case 2: icmp <op> %intrinsic, %phi
     else if (isa<PHINode>(Op1))
-      return foldPHIWithMinMaxHelper(PN, ICmp, Op0, ICmp->getSwappedPredicate());
+      return foldPHIWithMinMaxHelper(PN, ICmp, Op0,
+                                     ICmp->getSwappedPredicate());
   }
   return nullptr;
 }

>From 4156401962c0a45ff63aa98f87aaab7eabbb8c2f Mon Sep 17 00:00:00 2001
From: PeterChou1 <peter.chou at mail.utoronto.ca>
Date: Mon, 11 Mar 2024 14:03:48 -0400
Subject: [PATCH 4/4] [llvm][instcombine] refactor optimization for folding PHI
 ICmp to InstCombineCompares

---
 .../InstCombine/InstCombineCompares.cpp       | 87 ++++++++++++++++++
 .../InstCombine/InstCombineInternal.h         |  7 +-
 .../Transforms/InstCombine/InstCombinePHI.cpp | 92 -------------------
 .../Transforms/InstCombine/fold-phi-minmax.ll | 73 ++++++++++++---
 4 files changed, 150 insertions(+), 109 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fc2688f425bb8f..34be7356a0a886 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6908,6 +6908,90 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
   return nullptr;
 }
 
+/// Helper for foldICmpPHIWithMinMax. \p PN is the PHI operand of the ICmp
+/// \p I, \p Z is the other ICmp operand, and \p Pred is the predicate of the
+/// compare written as "icmp Pred PN, Z". For every incoming value of PN that
+/// is a min/max intrinsic with the same predicate, e.g. umin for ult or umax
+/// for ugt, if one operand is known to lie on the far side of Z (>= Z for a
+/// min, <= Z for a max) then the compare only depends on the other operand,
+/// so the incoming value is replaced by that other operand.
+/// Returns &I if any incoming value was replaced, nullptr otherwise.
+Instruction *InstCombinerImpl::foldICmpPHIWithMinMaxHelper(
+    PHINode &PN, Instruction &I, Value *Z, ICmpInst::Predicate Pred) {
+  // Rewriting an incoming value changes the PHI itself, which is only sound
+  // when the ICmp is its sole user.
+  if (!PN.hasOneUse())
+    return nullptr;
+
+  // Facts about Z are derived at the end of each incoming block. They only
+  // carry over to the ICmp if Z has the same value there, so reject a Z that
+  // is defined in PN's block or on a cycle through it (e.g. a loop header
+  // PHI or any value recomputed every iteration).
+  if (auto *ZI = dyn_cast<Instruction>(Z))
+    if (!DT.properlyDominates(ZI->getParent(), PN.getParent()))
+      return nullptr;
+
+  // For a min (Pred is lt) this asks "Op >= Z", for a max (gt) "Op <= Z".
+  ICmpInst::Predicate SwappedPred =
+      ICmpInst::getNonStrictPredicate(ICmpInst::getSwappedPredicate(Pred));
+
+  // True only if "Op SwappedPred Z" simplifies to the constant true at the
+  // context described by Q.
+  auto IsKnownFarSide = [&](Value *Op, const SimplifyQuery &Q) {
+    Value *Res = simplifyICmpInst(SwappedPred, Op, Z, Q);
+    return Res && match(Res, m_One());
+  };
+
+  bool Changed = false;
+  for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) {
+    auto *MinMax = dyn_cast<MinMaxIntrinsic>(PN.getIncomingValue(OpNum));
+    if (!MinMax || Pred != MinMax->getPredicate())
+      continue;
+
+    Value *X = MinMax->getLHS();
+    Value *Y = MinMax->getRHS();
+    // Reason at the point where the value flows into PN.
+    SimplifyQuery Q =
+        SQ.getWithInstruction(PN.getIncomingBlock(OpNum)->getTerminator());
+
+    // if X >= Z:                       if X <= Z:
+    //   %min = umin(X, Y)                %max = umax(X, Y)
+    //   %phi = phi [ %min, .. ]          %phi = phi [ %max, .. ]
+    //   %cmp = icmp ult %phi, Z          %cmp = icmp ugt %phi, Z
+    // then min(X, Y) < Z <=> Y < Z and max(X, Y) > Z <=> Y > Z, so Y can
+    // flow into the PHI instead. Either operand may play the role of X; if
+    // both qualify, either replacement is correct.
+    Value *Replacement = nullptr;
+    if (IsKnownFarSide(X, Q))
+      Replacement = Y;
+    else if (IsKnownFarSide(Y, Q))
+      Replacement = X;
+    if (!Replacement)
+      continue;
+
+    // Y (resp. X) is an operand of MinMax, so it is available wherever
+    // MinMax was, including the end of the incoming block.
+    replaceOperand(PN, OpNum, Replacement);
+    Changed = true;
+  }
+  return Changed ? &I : nullptr;
+}
+
+/// Fold min/max intrinsics feeding a PHI whose only user is \p Cmp. The PHI
+/// may be either operand; the predicate is swapped when it is the RHS.
+Instruction *InstCombinerImpl::foldICmpPHIWithMinMax(ICmpInst &Cmp) {
+  // case 1: icmp <pred> %phi, %other
+  if (auto *PNOp0 = dyn_cast<PHINode>(Cmp.getOperand(0)))
+    if (Instruction *Res = foldICmpPHIWithMinMaxHelper(
+            *PNOp0, Cmp, Cmp.getOperand(1), Cmp.getPredicate()))
+      return Res;
+  // case 2: icmp <pred> %other, %phi (also tried when case 1 did not fold)
+  if (auto *PNOp1 = dyn_cast<PHINode>(Cmp.getOperand(1)))
+    return foldICmpPHIWithMinMaxHelper(*PNOp1, Cmp, Cmp.getOperand(0),
+                                       Cmp.getSwappedPredicate());
+  return nullptr;
+}
+
 // This helper will be called with icmp operands in both orders.
 Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
                                                    Value *Op0, Value *Op1,
@@ -7287,6 +7371,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = foldReductionIdiom(I, Builder, DL))
     return Res;
 
+  if (Instruction *Res = foldICmpPHIWithMinMax(I))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index c9221dae533346..f5dc295aafd121 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -624,9 +624,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldPHIArgLoadIntoPHI(PHINode &PN);
   Instruction *foldPHIArgZextsIntoPHI(PHINode &PN);
   Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN);
-  Instruction *foldPHIWithMinMax(PHINode &PN);
-  Instruction *foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
-                                       ICmpInst::Predicate Pred);
 
   /// If an integer typed PHI has only one use which is an IntToPtr operation,
   /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise
@@ -724,6 +721,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldICmpCommutative(ICmpInst::Predicate Pred, Value *Op0,
                                    Value *Op1, ICmpInst &CxtI);
 
+  Instruction *foldICmpPHIWithMinMaxHelper(PHINode &PN, Instruction &I,
+                                           Value *Z, ICmpInst::Predicate Pred);
+  Instruction *foldICmpPHIWithMinMax(ICmpInst &Cmp);
+
   // Helpers of visitSelectInst().
   Instruction *foldSelectOfBools(SelectInst &SI);
   Instruction *foldSelectExtConst(SelectInst &Sel);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 9cae84741722f6..46bca4b722a03a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -611,95 +611,6 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) {
   return NewGEP;
 }
 
-/// helper function for foldPHIWithMinMax
-Instruction *
-InstCombinerImpl::foldPHIWithMinMaxHelper(PHINode &PN, Instruction *I, Value *Z,
-                                          ICmpInst::Predicate Pred) {
-
-  auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> {
-    if (!Val)
-      return std::nullopt;
-    if (match(Val, m_One()))
-      return true;
-    if (match(Val, m_Zero()))
-      return false;
-    return std::nullopt;
-  };
-
-  ICmpInst::Predicate SwappedPred =
-      ICmpInst::getNonStrictPredicate(ICmpInst::getSwappedPredicate(Pred));
-  for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) {
-    if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(PN.getIncomingValue(OpNum))) {
-      if (Pred != MinMax->getPredicate())
-        continue;
-
-      Value *X = MinMax->getLHS();
-      Value *Y = MinMax->getRHS();
-
-      SimplifyQuery Q = SQ.getWithInstruction(I);
-
-      auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, X, Z, Q));
-      auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(SwappedPred, Y, Z, Q));
-
-      if (!CmpXZ.has_value() && !CmpYZ.has_value())
-        continue;
-      if (CmpXZ.has_value() && CmpYZ.has_value())
-        continue;
-
-      // swap XZ with YZ so XZ always has value
-      if (!CmpXZ.has_value()) {
-        std::swap(X, Y);
-        std::swap(CmpXZ, CmpYZ);
-      }
-
-      switch (Pred) {
-      case ICmpInst::ICMP_SLT:
-      case ICmpInst::ICMP_ULT:
-      case ICmpInst::ICMP_SGT:
-      case ICmpInst::ICMP_UGT:
-        // if X >= Z
-        // %min = llvm.min ( X, Y )
-        // %phi = phi %min ...         =>  %phi = phi Y ..
-        // %cmp = icmp lt %phi, Z          %cmp = icmp %phi, Z
-        // if X <= Z
-        // %max = llvm.max ( X, Y )
-        // %phi = phi %max ...         =>  %phi = phi Y ..
-        // %cmp = icmp gt %phi, Z          %cmp = icmp %phi, Z
-        if (CmpXZ.value()) {
-          if (MinMax->hasOneUse()) {
-            MinMax->eraseFromParent();
-          }
-          PN.setIncomingValue(OpNum, Y);
-        }
-        break;
-      default:
-        break;
-      }
-    }
-  }
-  return nullptr;
-}
-
-/// Fold max min intrinsic into PHI instruction
-Instruction *InstCombinerImpl::foldPHIWithMinMax(PHINode &PN) {
-
-  if (!PN.hasOneUse())
-    return nullptr;
-
-  // The PHI instruction must only be used by a ICmp Instruction
-  if (ICmpInst *ICmp = dyn_cast<ICmpInst>(PN.getUniqueUndroppableUser())) {
-    Value *Op0 = ICmp->getOperand(0), *Op1 = ICmp->getOperand(1);
-    // case 1: icmp <op> %phi, %intrinsic
-    if (isa<PHINode>(Op0))
-      return foldPHIWithMinMaxHelper(PN, ICmp, Op1, ICmp->getPredicate());
-    // case 2: icmp <op> %intrinsic, %phi
-    else if (isa<PHINode>(Op1))
-      return foldPHIWithMinMaxHelper(PN, ICmp, Op0,
-                                     ICmp->getSwappedPredicate());
-  }
-  return nullptr;
-}
-
 /// Return true if we know that it is safe to sink the load out of the block
 /// that defines it. This means that it must be obvious the value of the load is
 /// not changed from the point of the load to the end of the block it is in.
@@ -1740,8 +1651,5 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
   if (Value *Res = foldDependentIVs(PN, Builder))
     return replaceInstUsesWith(PN, Res);
 
-  if (Instruction *Result = foldPHIWithMinMax(PN))
-    return Result;
-
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll b/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
index 55b464cdc58248..c7cf0c9fc4f18f 100644
--- a/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
+++ b/llvm/test/Transforms/InstCombine/fold-phi-minmax.ll
@@ -1,9 +1,8 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt < %s -passes=instcombine -S | FileCheck %s
 
-
 ; test phi combine less than (equal)
-define i1 @src0(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src0(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src0(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -26,7 +25,7 @@ loop:
 }
 
 ; test phi combine less than (swapped) good
-define i1 @src1(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src1(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src1(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -49,7 +48,7 @@ loop:
 }
 
 ; test phi combine less than (swapped) bad
-define i1 @src2(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src2(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src2(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -74,7 +73,7 @@ loop:
 
 
 ; test phi combine less than (reversed)
-define i1 @src3(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src3(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src3(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -98,7 +97,7 @@ loop:
 }
 
 ; test phi combine less than (over)
-define i1 @src4(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src4(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src4(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -122,7 +121,7 @@ loop:
 }
 
 ; test phi combine less than (under)
-define i1 @src5(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src5(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src5(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -147,7 +146,7 @@ loop:
 }
 
 ; test phi combine greater than (equal)
-define i1 @src6(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src6(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src6(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -171,7 +170,7 @@ loop:
 }
 
 ; test phi combine greater than (over)
-define i1 @src7(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src7(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src7(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -196,7 +195,7 @@ loop:
 }
 
 ; test phi combine greater than (under)
-define i1 @src8(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src8(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src8(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -220,7 +219,7 @@ loop:
 }
 
 ; test phi combine greater than (swapped-equal) good
-define i1 @src9(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src9(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src9(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -243,7 +242,7 @@ loop:
 }
 
 ; test phi combine greater than (swapped-over) good
-define i1 @src11(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src11(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src11(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -266,7 +265,7 @@ loop:
 }
 
 ; test phi combine greater than (swapped-under) bad
-define i1 @src12(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src12(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src12(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -290,7 +289,7 @@ loop:
 }
 
 ; test phi combine less than (swapped-equal) bad
-define i1 @src13(i32 %a, i32 %b, i1 %c) #0 {
+define i1 @src13(i32 %a, i32 %b, i1 %c) {
 ; CHECK-LABEL: @src13(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
@@ -312,3 +311,49 @@ loop:
   %cmp = icmp ult i32 6, %ind
   ret i1 %cmp
 }
+
+; test phi combine (min) with both operands constant: umin(7, 6) folds to 6
+define i1 @src14(i32 %a, i32 %b, i1 %c) {
+; CHECK-LABEL: @src14(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ 6, [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umin.i32(i32 7, i32 6)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ult i32 6, %ind
+  ret i1 %cmp
+}
+
+; test phi combine (max) with both operands constant: umax(7, 6) folds to 7
+define i1 @src15(i32 %a, i32 %b, i1 %c) {
+; CHECK-LABEL: @src15(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[C:%.*]], label [[THEN:%.*]], label [[LOOP:%.*]]
+; CHECK:       then:
+; CHECK-NEXT:    br label [[LOOP]]
+; CHECK:       loop:
+; CHECK-NEXT:    [[IND:%.*]] = phi i32 [ 7, [[THEN]] ], [ [[B:%.*]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[IND]], 6
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  br i1 %c, label %then, label %loop
+then:
+  %min = call i32 @llvm.umax.i32(i32 7, i32 6)
+  br label %loop
+loop:
+  %ind = phi i32 [ %min, %then ], [ %b, %entry ]
+  %cmp = icmp ugt i32 6, %ind
+  ret i1 %cmp
+}



More information about the llvm-commits mailing list