[llvm] 92a11eb - [ConstraintElim] Add facts implied by MinMaxIntrinsic

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 24 00:04:52 PDT 2023


Author: Yingwei Zheng
Date: 2023-07-24T15:03:34+08:00
New Revision: 92a11eb32c92d10132f685f9896e8f044c4c2f02

URL: https://github.com/llvm/llvm-project/commit/92a11eb32c92d10132f685f9896e8f044c4c2f02
DIFF: https://github.com/llvm/llvm-project/commit/92a11eb32c92d10132f685f9896e8f044c4c2f02.diff

LOG: [ConstraintElim] Add facts implied by MinMaxIntrinsic

Fixes https://github.com/llvm/llvm-project/issues/63896 and https://github.com/rust-lang/rust/issues/113757.
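This patch adds the facts implied by the llvm.smin/smax/umin/umax intrinsics: the result is non-strictly ordered against each operand (e.g. smin(a, b) s<= a and smin(a, b) s<= b; umax(a, b) u>= a and umax(a, b) u>= b).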

Reviewed By: fhahn

Differential Revision: https://reviews.llvm.org/D155412

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
    llvm/test/Transforms/ConstraintElimination/minmax.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index e5773438fe5935..15628d32280d8e 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -784,6 +784,11 @@ void State::addInfoFor(BasicBlock &BB) {
       continue;
     }
 
+    if (isa<MinMaxIntrinsic>(&I)) {
+      WorkList.push_back(FactOrCheck::getFact(DT.getNode(&BB), &I));
+      continue;
+    }
+
     Value *Cond;
     // For now, just handle assumes with a single compare as condition.
     if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) &&
@@ -1363,22 +1368,14 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
     }
 
     LLVM_DEBUG(dbgs() << "fact to add to the system: " << *CB.Inst << "\n");
-    ICmpInst::Predicate Pred;
-    Value *A, *B;
-    Value *Cmp = CB.Inst;
-    match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
-    if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
+    auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B) {
       if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) {
         LLVM_DEBUG(
             dbgs()
             << "Skip adding constraint because system has too many rows.\n");
-        continue;
+        return;
       }
 
-      // Use the inverse predicate if required.
-      if (CB.Not)
-        Pred = CmpInst::getInversePredicate(Pred);
-
       Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
       if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size())
         ReproducerCondStack.emplace_back(Pred, A, B);
@@ -1394,6 +1391,25 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
                                            nullptr, nullptr);
         }
       }
+    };
+
+    ICmpInst::Predicate Pred;
+    if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) {
+      Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
+      AddFact(Pred, MinMax, MinMax->getLHS());
+      AddFact(Pred, MinMax, MinMax->getRHS());
+      continue;
+    }
+
+    Value *A, *B;
+    Value *Cmp = CB.Inst;
+    match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
+    if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
+      // Use the inverse predicate if required.
+      if (CB.Not)
+        Pred = CmpInst::getInversePredicate(Pred);
+
+      AddFact(Pred, A, B);
     }
   }
 

diff  --git a/llvm/test/Transforms/ConstraintElimination/minmax.ll b/llvm/test/Transforms/ConstraintElimination/minmax.ll
index 3dbd0f8f62d23c..43a9b6931b3292 100644
--- a/llvm/test/Transforms/ConstraintElimination/minmax.ll
+++ b/llvm/test/Transforms/ConstraintElimination/minmax.ll
@@ -11,7 +11,7 @@ define i1 @umax_ugt(i32 %x, i32 %y) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp uge i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 true, true
 ; CHECK-NEXT:    ret i1 [[RET]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -39,7 +39,7 @@ define i1 @umax_uge(i32 %x, i32 %y) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp uge i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], true
 ; CHECK-NEXT:    ret i1 [[RET]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -67,7 +67,7 @@ define i1 @umin_ult(i32 %x, i32 %y) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ult i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ule i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 true, true
 ; CHECK-NEXT:    ret i1 [[RET]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -95,7 +95,7 @@ define i1 @umin_ule(i32 %x, i32 %y) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ult i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ule i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], true
 ; CHECK-NEXT:    ret i1 [[RET]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -123,7 +123,7 @@ define i1 @smax_sgt(i32 %x, i32 %y) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp sge i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 true, true
 ; CHECK-NEXT:    ret i1 [[RET]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -151,7 +151,7 @@ define i1 @smax_sge(i32 %x, i32 %y) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp sge i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], true
 ; CHECK-NEXT:    ret i1 [[RET]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -179,7 +179,7 @@ define i1 @smin_slt(i32 %x, i32 %y) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp sle i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 true, true
 ; CHECK-NEXT:    ret i1 [[RET]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -207,7 +207,7 @@ define i1 @smin_sle(i32 %x, i32 %y) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i32 [[Y]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp sle i32 [[Y]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP2]], true
 ; CHECK-NEXT:    ret i1 [[RET]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -235,7 +235,7 @@ define i1 @umax_uge_ugt_with_add_nuw(i32 %x, i32 %y) {
 ; CHECK-NEXT:    br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
-; CHECK-NEXT:    ret i1 [[CMP2]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
 ;
@@ -297,7 +297,7 @@ define i1 @umax_ugt_ugt_both(i32 %x, i32 %y, i32 %z) {
 ; CHECK:       if:
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[Z]], [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ugt i32 [[Z]], [[Y]]
-; CHECK-NEXT:    [[AND:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT:    [[AND:%.*]] = xor i1 true, true
 ; CHECK-NEXT:    ret i1 [[AND]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i1 false
@@ -323,7 +323,7 @@ define i1 @smin_branchless(i32 %x, i32 %y) {
 ; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp sle i32 [[MIN]], [[X]]
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[MIN]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP1]], [[CMP2]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 true, false
 ; CHECK-NEXT:    ret i1 [[RET]]
 ;
 entry:


        


More information about the llvm-commits mailing list