[llvm] [ConstraintElim] Add facts implied by intrinsics if they are used by other constraints (PR #80121)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 4 05:17:15 PST 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/80121

>From 3ac2932ac2954c60b091830cdfd0c7d5c44eed7e Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 31 Jan 2024 17:14:37 +0800
Subject: [PATCH] [ConstraintElim] Add facts implied by intrinsics if they are
 used by other constraints

---
 llvm/include/llvm/Analysis/ConstraintSystem.h | 49 +++++++++++----
 llvm/lib/Analysis/ConstraintSystem.cpp        | 37 +++++++----
 .../Scalar/ConstraintElimination.cpp          | 62 +++++++++++--------
 .../ConstraintElimination/minmax.ll           |  4 +-
 4 files changed, 98 insertions(+), 54 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ConstraintSystem.h b/llvm/include/llvm/Analysis/ConstraintSystem.h
index 7b02b618f7cb4..24fd2a8918065 100644
--- a/llvm/include/llvm/Analysis/ConstraintSystem.h
+++ b/llvm/include/llvm/Analysis/ConstraintSystem.h
@@ -35,11 +35,19 @@ class ConstraintSystem {
     return 0;
   }
 
-  static int64_t getLastCoefficient(ArrayRef<Entry> Row, uint16_t Id) {
-    if (Row.empty())
+  struct ConstraintRow {
+    SmallVector<Entry, 8> Entries;
+    bool IsWellDefined;
+
+    ConstraintRow(SmallVector<Entry, 8> Entries, bool IsWellDefined)
+        : Entries(std::move(Entries)), IsWellDefined(IsWellDefined) {}
+  };
+
+  static int64_t getLastCoefficient(const ConstraintRow &Row, uint16_t Id) {
+    if (Row.Entries.empty())
       return 0;
-    if (Row.back().Id == Id)
-      return Row.back().Coefficient;
+    if (Row.Entries.back().Id == Id)
+      return Row.Entries.back().Coefficient;
     return 0;
   }
 
@@ -48,12 +56,16 @@ class ConstraintSystem {
   /// Current linear constraints in the system.
   /// An entry of the form c0, c1, ... cn represents the following constraint:
   ///   c0 >= v0 * c1 + .... + v{n-1} * cn
-  SmallVector<SmallVector<Entry, 8>, 4> Constraints;
+  SmallVector<ConstraintRow, 4> Constraints;
 
   /// A map of variables (IR values) to their corresponding index in the
   /// constraint system.
   DenseMap<Value *, unsigned> Value2Index;
 
+  /// Maps a variable index to the number of well-defined constraints that
+  /// currently use that variable.
+  DenseMap<unsigned, unsigned> WellDefinedVariableRefCount;
+
   // Eliminate constraints from the system using Fourier–Motzkin elimination.
   bool eliminateUsingFM();
 
@@ -74,14 +86,14 @@ class ConstraintSystem {
   ConstraintSystem(const DenseMap<Value *, unsigned> &Value2Index)
       : NumVariables(Value2Index.size()), Value2Index(Value2Index) {}
 
-  bool addVariableRow(ArrayRef<int64_t> R) {
+  bool addVariableRow(ArrayRef<int64_t> R, bool IsWellDefined = true) {
     assert(Constraints.empty() || R.size() == NumVariables);
     // If all variable coefficients are 0, the constraint does not provide any
     // usable information.
     if (all_of(ArrayRef(R).drop_front(1), [](int64_t C) { return C == 0; }))
       return false;
 
-    SmallVector<Entry, 4> NewRow;
+    SmallVector<Entry, 8> NewRow;
     for (const auto &[Idx, C] : enumerate(R)) {
       if (C == 0)
         continue;
@@ -89,7 +101,12 @@ class ConstraintSystem {
     }
     if (Constraints.empty())
       NumVariables = R.size();
-    Constraints.push_back(std::move(NewRow));
+    Constraints.emplace_back(std::move(NewRow), IsWellDefined);
+    if (IsWellDefined) {
+      for (auto &Entry : Constraints.back().Entries)
+        if (Entry.Id != 0)
+          ++WellDefinedVariableRefCount[Entry.Id];
+    }
     return true;
   }
 
@@ -98,14 +115,14 @@ class ConstraintSystem {
     return Value2Index;
   }
 
-  bool addVariableRowFill(ArrayRef<int64_t> R) {
+  bool addVariableRowFill(ArrayRef<int64_t> R, bool IsWellDefined = true) {
     // If all variable coefficients are 0, the constraint does not provide any
     // usable information.
     if (all_of(ArrayRef(R).drop_front(1), [](int64_t C) { return C == 0; }))
       return false;
 
     NumVariables = std::max(R.size(), NumVariables);
-    return addVariableRow(R);
+    return addVariableRow(R, IsWellDefined);
   }
 
   /// Returns true if there may be a solution for the constraints in the system.
@@ -147,12 +164,20 @@ class ConstraintSystem {
   SmallVector<int64_t> getLastConstraint() const {
     assert(!Constraints.empty() && "Constraint system is empty");
     SmallVector<int64_t> Result(NumVariables, 0);
-    for (auto &Entry : Constraints.back())
+    for (auto &Entry : Constraints.back().Entries)
       Result[Entry.Id] = Entry.Coefficient;
     return Result;
   }
 
-  void popLastConstraint() { Constraints.pop_back(); }
+  void popLastConstraint() {
+    if (Constraints.back().IsWellDefined) {
+      for (auto &Entry : Constraints.back().Entries)
+        if (Entry.Id != 0)
+          --WellDefinedVariableRefCount[Entry.Id];
+    }
+    Constraints.pop_back();
+  }
+
   void popLastNVariables(unsigned N) {
     assert(NumVariables > N);
     NumVariables -= N;
diff --git a/llvm/lib/Analysis/ConstraintSystem.cpp b/llvm/lib/Analysis/ConstraintSystem.cpp
index 1a9c7c21e9ced..99f7f739a0602 100644
--- a/llvm/lib/Analysis/ConstraintSystem.cpp
+++ b/llvm/lib/Analysis/ConstraintSystem.cpp
@@ -33,12 +33,12 @@ bool ConstraintSystem::eliminateUsingFM() {
 
   // First, either remove the variable in place if it is 0 or add the row to
   // RemainingRows and remove it from the system.
-  SmallVector<SmallVector<Entry, 8>, 4> RemainingRows;
+  SmallVector<ConstraintRow, 4> RemainingRows;
   for (unsigned R1 = 0; R1 < Constraints.size();) {
-    SmallVector<Entry, 8> &Row1 = Constraints[R1];
+    ConstraintRow &Row1 = Constraints[R1];
     if (getLastCoefficient(Row1, LastIdx) == 0) {
-      if (Row1.size() > 0 && Row1.back().Id == LastIdx)
-        Row1.pop_back();
+      if (Row1.Entries.size() > 0 && Row1.Entries.back().Id == LastIdx)
+        Row1.Entries.pop_back();
       R1++;
     } else {
       std::swap(Constraints[R1], Constraints.back());
@@ -74,8 +74,8 @@ bool ConstraintSystem::eliminateUsingFM() {
       SmallVector<Entry, 8> NR;
       unsigned IdxUpper = 0;
       unsigned IdxLower = 0;
-      auto &LowerRow = RemainingRows[LowerR];
-      auto &UpperRow = RemainingRows[UpperR];
+      auto &LowerRow = RemainingRows[LowerR].Entries;
+      auto &UpperRow = RemainingRows[UpperR].Entries;
       while (true) {
         if (IdxUpper >= UpperRow.size() || IdxLower >= LowerRow.size())
           break;
@@ -112,7 +112,7 @@ bool ConstraintSystem::eliminateUsingFM() {
       }
       if (NR.empty())
         continue;
-      Constraints.push_back(std::move(NR));
+      Constraints.emplace_back(std::move(NR), /*IsWellDefined*/ true);
       // Give up if the new system gets too big.
       if (Constraints.size() > 500)
         return false;
@@ -124,6 +124,10 @@ bool ConstraintSystem::eliminateUsingFM() {
 }
 
 bool ConstraintSystem::mayHaveSolutionImpl() {
+  // Make sure that all constraints in the system are well defined.
+  assert(all_of(Constraints,
+                [](const ConstraintRow &Row) { return Row.IsWellDefined; }));
+
   while (!Constraints.empty() && NumVariables > 1) {
     if (!eliminateUsingFM())
       return true;
@@ -133,10 +137,10 @@ bool ConstraintSystem::mayHaveSolutionImpl() {
     return true;
 
   return all_of(Constraints, [](auto &R) {
-    if (R.empty())
+    if (R.Entries.empty())
       return true;
-    if (R[0].Id == 0)
-      return R[0].Coefficient >= 0;
+    if (R.Entries[0].Id == 0)
+      return R.Entries[0].Coefficient >= 0;
     return true;
   });
 }
@@ -161,7 +165,7 @@ void ConstraintSystem::dump() const {
   if (Constraints.empty())
     return;
   SmallVector<std::string> Names = getVarNamesList();
-  for (const auto &Row : Constraints) {
+  for (const auto &[Row, IsWellDefined] : Constraints) {
     SmallVector<std::string, 16> Parts;
     for (unsigned I = 0, S = Row.size(); I < S; ++I) {
       if (Row[I].Id >= NumVariables)
@@ -204,6 +208,15 @@ bool ConstraintSystem::isConditionImplied(SmallVector<int64_t, 8> R) const {
     return false;
 
   auto NewSystem = *this;
-  NewSystem.addVariableRow(R);
+  NewSystem.addVariableRow(R, /*IsWellDefined*/ true);
+  // Remove invalid constraints whose variables may be poison.
+  erase_if(NewSystem.Constraints, [&](ConstraintRow &R) {
+    if (!R.IsWellDefined) {
+      R.IsWellDefined = all_of(R.Entries, [&](const Entry &E) {
+        return E.Id == 0 || NewSystem.WellDefinedVariableRefCount[E.Id] > 0;
+      });
+    }
+    return !R.IsWellDefined;
+  });
   return !NewSystem.mayHaveSolution();
 }
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index 97cf5ebe3ca06..bc83179b4e97e 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -306,7 +306,8 @@ class ConstraintInfo {
   bool doesHold(CmpInst::Predicate Pred, Value *A, Value *B) const;
 
   void addFact(CmpInst::Predicate Pred, Value *A, Value *B, unsigned NumIn,
-               unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack);
+               unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack,
+               bool IsWellDefined);
 
   /// Turn a comparison of the form \p Op0 \p Pred \p Op1 into a vector of
   /// constraints, using indices from the corresponding constraint system.
@@ -329,7 +330,8 @@ class ConstraintInfo {
   /// system if \p Pred is signed/unsigned.
   void transferToOtherSystem(CmpInst::Predicate Pred, Value *A, Value *B,
                              unsigned NumIn, unsigned NumOut,
-                             SmallVectorImpl<StackEntry> &DFSInStack);
+                             SmallVectorImpl<StackEntry> &DFSInStack,
+                             bool IsWellDefined);
 };
 
 /// Represents a (Coefficient * Variable) entry after IR decomposition.
@@ -819,7 +821,8 @@ bool ConstraintInfo::doesHold(CmpInst::Predicate Pred, Value *A,
 
 void ConstraintInfo::transferToOtherSystem(
     CmpInst::Predicate Pred, Value *A, Value *B, unsigned NumIn,
-    unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack) {
+    unsigned NumOut, SmallVectorImpl<StackEntry> &DFSInStack,
+    bool IsWellDefined) {
   auto IsKnownNonNegative = [this](Value *V) {
     return doesHold(CmpInst::ICMP_SGE, V, ConstantInt::get(V->getType(), 0)) ||
            isKnownNonNegative(V, DL, /*Depth=*/MaxAnalysisRecursionDepth - 1);
@@ -839,9 +842,9 @@ void ConstraintInfo::transferToOtherSystem(
     //  If B is a signed positive constant, then A >=s 0 and A <s (or <=s) B.
     if (IsKnownNonNegative(B)) {
       addFact(CmpInst::ICMP_SGE, A, ConstantInt::get(B->getType(), 0), NumIn,
-              NumOut, DFSInStack);
+              NumOut, DFSInStack, IsWellDefined);
       addFact(CmpInst::getSignedPredicate(Pred), A, B, NumIn, NumOut,
-              DFSInStack);
+              DFSInStack, IsWellDefined);
     }
     break;
   case CmpInst::ICMP_UGE:
@@ -849,27 +852,30 @@ void ConstraintInfo::transferToOtherSystem(
     //  If A is a signed positive constant, then B >=s 0 and A >s (or >=s) B.
     if (IsKnownNonNegative(A)) {
       addFact(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), 0), NumIn,
-              NumOut, DFSInStack);
+              NumOut, DFSInStack, IsWellDefined);
       addFact(CmpInst::getSignedPredicate(Pred), A, B, NumIn, NumOut,
-              DFSInStack);
+              DFSInStack, IsWellDefined);
     }
     break;
   case CmpInst::ICMP_SLT:
     if (IsKnownNonNegative(A))
-      addFact(CmpInst::ICMP_ULT, A, B, NumIn, NumOut, DFSInStack);
+      addFact(CmpInst::ICMP_ULT, A, B, NumIn, NumOut, DFSInStack,
+              IsWellDefined);
     break;
   case CmpInst::ICMP_SGT: {
     if (doesHold(CmpInst::ICMP_SGE, B, ConstantInt::get(B->getType(), -1)))
       addFact(CmpInst::ICMP_UGE, A, ConstantInt::get(B->getType(), 0), NumIn,
-              NumOut, DFSInStack);
+              NumOut, DFSInStack, IsWellDefined);
     if (IsKnownNonNegative(B))
-      addFact(CmpInst::ICMP_UGT, A, B, NumIn, NumOut, DFSInStack);
+      addFact(CmpInst::ICMP_UGT, A, B, NumIn, NumOut, DFSInStack,
+              IsWellDefined);
 
     break;
   }
   case CmpInst::ICMP_SGE:
     if (IsKnownNonNegative(B))
-      addFact(CmpInst::ICMP_UGE, A, B, NumIn, NumOut, DFSInStack);
+      addFact(CmpInst::ICMP_UGE, A, B, NumIn, NumOut, DFSInStack,
+              IsWellDefined);
     break;
   }
 }
@@ -1068,10 +1074,6 @@ void State::addInfoFor(BasicBlock &BB) {
       // TODO: handle llvm.abs as well
       WorkList.push_back(
           FactOrCheck::getCheck(DT.getNode(&BB), cast<CallInst>(&I)));
-      // TODO: Check if it is possible to instead only added the min/max facts
-      // when simplifying uses of the min/max intrinsics.
-      if (!isGuaranteedNotToBePoison(&I))
-        break;
       [[fallthrough]];
     case Intrinsic::abs:
       WorkList.push_back(FactOrCheck::getInstFact(DT.getNode(&BB), &I));
@@ -1464,7 +1466,8 @@ static bool checkOrAndOpImpliedByOther(
 
   // Optimistically add fact from first condition.
   unsigned OldSize = DFSInStack.size();
-  Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
+  Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack,
+               /*IsWellDefined*/ true);
   if (OldSize == DFSInStack.size())
     return false;
 
@@ -1496,7 +1499,8 @@ static bool checkOrAndOpImpliedByOther(
 
 void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B,
                              unsigned NumIn, unsigned NumOut,
-                             SmallVectorImpl<StackEntry> &DFSInStack) {
+                             SmallVectorImpl<StackEntry> &DFSInStack,
+                             bool IsWellDefined) {
   // If the constraint has a pre-condition, skip the constraint if it does not
   // hold.
   SmallVector<Value *> NewVariables;
@@ -1513,7 +1517,7 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B,
   if (R.Coefficients.empty())
     return;
 
-  Added |= CSToUse.addVariableRowFill(R.Coefficients);
+  Added |= CSToUse.addVariableRowFill(R.Coefficients, IsWellDefined);
 
   // If R has been added to the system, add the new variables and queue it for
   // removal once it goes out-of-scope.
@@ -1539,7 +1543,7 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B,
         ConstraintTy VarPos(SmallVector<int64_t, 8>(Value2Index.size() + 1, 0),
                             false, false, false);
         VarPos.Coefficients[Value2Index[V]] = -1;
-        CSToUse.addVariableRow(VarPos.Coefficients);
+        CSToUse.addVariableRow(VarPos.Coefficients, IsWellDefined);
         DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned,
                                 SmallVector<Value *, 2>());
       }
@@ -1549,7 +1553,7 @@ void ConstraintInfo::addFact(CmpInst::Predicate Pred, Value *A, Value *B,
       // Also add the inverted constraint for equality constraints.
       for (auto &Coeff : R.Coefficients)
         Coeff *= -1;
-      CSToUse.addVariableRowFill(R.Coefficients);
+      CSToUse.addVariableRowFill(R.Coefficients, IsWellDefined);
 
       DFSInStack.emplace_back(NumIn, NumOut, R.IsSigned,
                               SmallVector<Value *, 2>());
@@ -1724,7 +1728,8 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
       continue;
     }
 
-    auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B) {
+    auto AddFact = [&](CmpInst::Predicate Pred, Value *A, Value *B,
+                       bool IsWellDefined) {
       LLVM_DEBUG(dbgs() << "Processing fact to add to the system: ";
                  dumpUnpackedICmp(dbgs(), Pred, A, B); dbgs() << "\n");
       if (Info.getCS(CmpInst::isSigned(Pred)).size() > MaxRows) {
@@ -1734,11 +1739,12 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
         return;
       }
 
-      Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
+      Info.addFact(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack, IsWellDefined);
       if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size())
         ReproducerCondStack.emplace_back(Pred, A, B);
 
-      Info.transferToOtherSystem(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack);
+      Info.transferToOtherSystem(Pred, A, B, CB.NumIn, CB.NumOut, DFSInStack,
+                                 IsWellDefined);
       if (ReproducerModule && DFSInStack.size() > ReproducerCondStack.size()) {
         // Add dummy entries to ReproducerCondStack to keep it in sync with
         // DFSInStack.
@@ -1756,14 +1762,16 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
       Value *X;
       if (match(CB.Inst, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) {
         // TODO: Add CB.Inst >= 0 fact.
-        AddFact(CmpInst::ICMP_SGE, CB.Inst, X);
+        bool IsWellDefined = isGuaranteedNotToBePoison(CB.Inst);
+        AddFact(CmpInst::ICMP_SGE, CB.Inst, X, IsWellDefined);
         continue;
       }
 
       if (auto *MinMax = dyn_cast<MinMaxIntrinsic>(CB.Inst)) {
         Pred = ICmpInst::getNonStrictPredicate(MinMax->getPredicate());
-        AddFact(Pred, MinMax, MinMax->getLHS());
-        AddFact(Pred, MinMax, MinMax->getRHS());
+        bool IsWellDefined = isGuaranteedNotToBePoison(CB.Inst);
+        AddFact(Pred, MinMax, MinMax->getLHS(), IsWellDefined);
+        AddFact(Pred, MinMax, MinMax->getRHS(), IsWellDefined);
         continue;
       }
     }
@@ -1791,7 +1799,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
       (void)Matched;
       assert(Matched && "Must have an assume intrinsic with a icmp operand");
     }
-    AddFact(Pred, A, B);
+    AddFact(Pred, A, B, /*IsWellDefined*/ true);
   }
 
   if (ReproducerModule && !ReproducerModule->functions().empty()) {
diff --git a/llvm/test/Transforms/ConstraintElimination/minmax.ll b/llvm/test/Transforms/ConstraintElimination/minmax.ll
index 68513ea10ad0f..bbb337eaf3e26 100644
--- a/llvm/test/Transforms/ConstraintElimination/minmax.ll
+++ b/llvm/test/Transforms/ConstraintElimination/minmax.ll
@@ -306,9 +306,7 @@ define i1 @smin_branchless(i32 %x, i32 %y) {
 ; CHECK-SAME: (i32 [[X:%.*]], i32 [[Y:%.*]]) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[MIN:%.*]] = call i32 @llvm.smin.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sle i32 [[MIN]], [[X]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[MIN]], [[X]]
-; CHECK-NEXT:    [[RET:%.*]] = xor i1 [[CMP1]], [[CMP2]]
+; CHECK-NEXT:    [[RET:%.*]] = xor i1 true, false
 ; CHECK-NEXT:    ret i1 [[RET]]
 ;
 entry:



More information about the llvm-commits mailing list