[llvm] 92a11eb - [ConstraintElim] Add facts implied by MinMaxIntrinsic
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 24 00:04:52 PDT 2023
Author: Yingwei Zheng
Date: 2023-07-24T15:03:34+08:00
New Revision: 92a11eb32c92d10132f685f9896e8f044c4c2f02
URL: https://github.com/llvm/llvm-project/commit/92a11eb32c92d10132f685f9896e8f044c4c2f02
DIFF: https://github.com/llvm/llvm-project/commit/92a11eb32c92d10132f685f9896e8f044c4c2f02.diff
LOG: [ConstraintElim] Add facts implied by MinMaxIntrinsic
Fixes https://github.com/llvm/llvm-project/issues/63896 and https://github.com/rust-lang/rust/issues/113757.
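This patch adds the facts implied by llvm.smin/smax/umin/umax intrinsics: the result is non-strictly ordered against each operand, e.g. umax(a, b) uge a and umax(a, b) uge b.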
Reviewed By: fhahn
Differential Revision: https://reviews.llvm.org/D155412
Added:
Modified:
llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
llvm/test/Transforms/ConstraintElimination/minmax.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index e5773438fe5935..15628d32280d8e 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -784,6 +784,11 @@ void State::addInfoFor(BasicBlock &BB) {
continue;
}
+ if (isa<MinMaxIntrinsic>(&I)) {
+ WorkList.push_back(FactOrCheck::getFact(DT.getNode(&BB), &I));
+ continue;
+ }
+
Value *Cond;
// For now, just handle assumes with a single compare as condition.
if (match(&I, m_Intrinsic<Intrinsic::assume>(m_Value(Cond))) &&
@@ -1363,22 +1368,14 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
}
LLVM_DEBUG(dbgs() << "fact to add to the system: " << *CB.Inst << "\n");
- ICmpInst::Predicate Pred;
- Value *A, *B;
- Value *Cmp = CB.Inst;
- match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
- if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
+ auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B) {
if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) {
LLVM_DEBUG(
dbgs()
<< "Skip adding constraint because system has too many rows.\n");
- continue;
+ return;
}
- // Use the inverse predicate if required.
- if (CB.Not)
- Pred = CmpInst::getInversePredicate(Pred);
-
Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size())
ReproducerCondStack.emplace_back(Pred, A, B);
@@ -1394,6 +1391,25 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT,
nullptr, nullptr);
}
}
+ };
+
+ ICmpInst::Predicate Pred;
+ if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) {
+ Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
+ AddFact(Pred, MinMax, MinMax->getLHS());
+ AddFact(Pred, MinMax, MinMax->getRHS());
+ continue;
+ }
+
+ Value *A, *B;
+ Value *Cmp = CB.Inst;
+ match(Cmp, m_Intrinsic<Intrinsic::assume>(m_Value(Cmp)));
+ if (match(Cmp, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
+ // Use the inverse predicate if required.
+ if (CB.Not)
+ Pred = CmpInst::getInversePredicate(Pred);
+
+ AddFact(Pred, A, B);
}
}
diff --git a/llvm/test/Transforms/ConstraintElimination/minmax.ll b/llvm/test/Transforms/ConstraintElimination/minmax.ll
index 3dbd0f8f62d23c..43a9b6931b3292 100644
--- a/llvm/test/Transforms/ConstraintElimination/minmax.ll
+++ b/llvm/test/Transforms/ConstraintElimination/minmax.ll
@@ -11,7 +11,7 @@ define i1 @umax_ugt(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp uge i32 [[Y]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -39,7 +39,7 @@ define i1 @umax_uge(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp uge i32 [[Y]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -67,7 +67,7 @@ define i1 @umin_ult(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp ule i32 [[Y]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -95,7 +95,7 @@ define i1 @umin_ule(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ult i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp ule i32 [[Y]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -123,7 +123,7 @@ define i1 @smax_sgt(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp sge i32 [[Y]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -151,7 +151,7 @@ define i1 @smax_sge(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp sge i32 [[Y]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -179,7 +179,7 @@ define i1 @smin_slt(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp sle i32 [[Y]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -207,7 +207,7 @@ define i1 @smin_sle(i32 %x, i32 %y) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i32 [[Y]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp sle i32 [[Y]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP2]], true
; CHECK-NEXT: ret i1 [[RET]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -235,7 +235,7 @@ define i1 @umax_uge_ugt_with_add_nuw(i32 %x, i32 %y) {
; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]]
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Y]], [[X]]
-; CHECK-NEXT: ret i1 [[CMP2]]
+; CHECK-NEXT: ret i1 true
; CHECK: end:
; CHECK-NEXT: ret i1 false
;
@@ -297,7 +297,7 @@ define i1 @umax_ugt_ugt_both(i32 %x, i32 %y, i32 %z) {
; CHECK: if:
; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[Z]], [[X]]
; CHECK-NEXT: [[CMP3:%.*]] = icmp ugt i32 [[Z]], [[Y]]
-; CHECK-NEXT: [[AND:%.*]] = xor i1 [[CMP2]], [[CMP3]]
+; CHECK-NEXT: [[AND:%.*]] = xor i1 true, true
; CHECK-NEXT: ret i1 [[AND]]
; CHECK: end:
; CHECK-NEXT: ret i1 false
@@ -323,7 +323,7 @@ define i1 @smin_branchless(i32 %x, i32 %y) {
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
; CHECK-NEXT: [[CMP1:%.*]] = icmp sle i32 [[MIN]], [[X]]
; CHECK-NEXT: [[CMP2:%.*]] = icmp sgt i32 [[MIN]], [[X]]
-; CHECK-NEXT: [[RET:%.*]] = xor i1 [[CMP1]], [[CMP2]]
+; CHECK-NEXT: [[RET:%.*]] = xor i1 true, false
; CHECK-NEXT: ret i1 [[RET]]
;
entry:
More information about the llvm-commits mailing list