[llvm] [ValueTracking] Support dominating known bits condition in and/or (PR #74728)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 7 07:38:39 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

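This extends computeKnownBits() support for dominating conditions to also handle and/or conditions. We look through either a logical `and` or a logical `or`, depending on which edge we're considering: `and` conditions on the true edge, and `or` conditions on the false (inverted) edge.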

This change is mainly for the sake of completeness, so we don't start missing optimizations if SimplifyCFG decides to merge some branches.

---
Full diff: https://github.com/llvm/llvm-project/pull/74728.diff


4 Files Affected:

- (modified) llvm/lib/Analysis/DomConditionCache.cpp (+33-17) 
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+22-10) 
- (modified) llvm/test/Transforms/InstCombine/known-bits.ll (+5-10) 
- (modified) llvm/test/Transforms/LoopVectorize/induction.ll (+15-15) 


``````````diff
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index 351881fe9e1f9..4a1eea57ed5bf 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -34,24 +34,40 @@ static void findAffectedValues(Value *Cond,
     }
   };
 
-  ICmpInst::Predicate Pred;
-  Value *A;
-  Constant *C;
-  if (match(Cond, m_ICmp(Pred, m_Value(A), m_Constant(C)))) {
-    AddAffected(A);
+  bool TopLevelIsAnd = match(Cond, m_LogicalAnd());
+  SmallVector<Value *, 8> Worklist;
+  SmallPtrSet<Value *, 8> Visited;
+  Worklist.push_back(Cond);
+  while (!Worklist.empty()) {
+    Value *V = Worklist.pop_back_val();
+    if (!Visited.insert(V).second)
+      continue;
 
-    if (ICmpInst::isEquality(Pred)) {
-      Value *X;
-      // (X & C) or (X | C) or (X ^ C).
-      // (X << C) or (X >>_s C) or (X >>_u C).
-      if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
-          match(A, m_Shift(m_Value(X), m_ConstantInt())))
-        AddAffected(X);
-    } else {
-      Value *X;
-      // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
-      if (match(A, m_Add(m_Value(X), m_ConstantInt())))
-        AddAffected(X);
+    ICmpInst::Predicate Pred;
+    Value *A, *B;
+    Constant *C;
+    // Only recurse into and/or if it matches the top-level and/or type.
+    if (TopLevelIsAnd ? match(V, m_LogicalAnd(m_Value(A), m_Value(B)))
+                      : match(V, m_LogicalOr(m_Value(A), m_Value(B)))) {
+      Worklist.push_back(A);
+      Worklist.push_back(B);
+    } else if (match(V, m_ICmp(Pred, m_Value(A), m_Constant(C)))) {
+      AddAffected(A);
+
+      if (ICmpInst::isEquality(Pred)) {
+        Value *X;
+        // (X & C) or (X | C) or (X ^ C).
+        // (X << C) or (X >>_s C) or (X >>_u C).
+        if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+            match(A, m_Shift(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      } else {
+        Value *X;
+        // Handle (A + C1) u< C2, which is the canonical form of
+        // A > C3 && A < C4.
+        if (match(A, m_Add(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      }
     }
   }
 }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2c52107dab23b..7db141a0ea758 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -705,28 +705,40 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
   }
 }
 
+static void computeKnownBitsFromCond(const Value *V, Value *Cond,
+                                     KnownBits &Known, unsigned Depth,
+                                     const SimplifyQuery &SQ, bool Invert) {
+  Value *A, *B;
+  if (Depth < MaxAnalysisRecursionDepth &&
+      (Invert ? match(Cond, m_LogicalOr(m_Value(A), m_Value(B)))
+              : match(Cond, m_LogicalAnd(m_Value(A), m_Value(B))))) {
+    computeKnownBitsFromCond(V, A, Known, Depth + 1, SQ, Invert);
+    computeKnownBitsFromCond(V, B, Known, Depth + 1, SQ, Invert);
+  }
+
+  if (auto *Cmp = dyn_cast<ICmpInst>(Cond))
+    computeKnownBitsFromCmp(
+        V, Invert ? Cmp->getInversePredicate() : Cmp->getPredicate(),
+        Cmp->getOperand(0), Cmp->getOperand(1), Known, Depth, SQ);
+}
+
 void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
-                                      unsigned Depth, const SimplifyQuery &Q) {
+                                       unsigned Depth, const SimplifyQuery &Q) {
   if (!Q.CxtI)
     return;
 
   if (Q.DC && Q.DT) {
     // Handle dominating conditions.
     for (BranchInst *BI : Q.DC->conditionsFor(V)) {
-      auto *Cmp = dyn_cast<ICmpInst>(BI->getCondition());
-      if (!Cmp)
-        continue;
-
       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
       if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
-        computeKnownBitsFromCmp(V, Cmp->getPredicate(), Cmp->getOperand(0),
-                                Cmp->getOperand(1), Known, Depth, Q);
+        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
+                                 /*Invert*/ false);
 
       BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
       if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
-        computeKnownBitsFromCmp(V, Cmp->getInversePredicate(),
-                                Cmp->getOperand(0), Cmp->getOperand(1), Known,
-                                Depth, Q);
+        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
+                                 /*Invert*/ true);
     }
 
     if (Known.hasConflict())
diff --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index e346330aa5b1e..246579cc4cd0c 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -105,8 +105,7 @@ define i8 @test_cond_and(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = and i1 [[CMP]], [[C:%.*]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -133,8 +132,7 @@ define i8 @test_cond_and_commuted(i8 %x, i1 %c1, i1 %c2) {
 ; CHECK-NEXT:    [[COND:%.*]] = and i1 [[C3]], [[CMP]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -161,8 +159,7 @@ define i8 @test_cond_logical_and(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i1 [[C:%.*]], i1 false
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -218,8 +215,7 @@ define i8 @test_cond_inv_or(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR1]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR2]]
+; CHECK-NEXT:    ret i8 -4
 ;
   %and = and i8 %x, 3
   %cmp = icmp ne i8 %and, 0
@@ -242,8 +238,7 @@ define i8 @test_cond_inv_logical_or(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP_NOT]], i1 [[C:%.*]], i1 false
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
diff --git a/llvm/test/Transforms/LoopVectorize/induction.ll b/llvm/test/Transforms/LoopVectorize/induction.ll
index a8cfac64258e8..9e7648b29cfa5 100644
--- a/llvm/test/Transforms/LoopVectorize/induction.ll
+++ b/llvm/test/Transforms/LoopVectorize/induction.ll
@@ -3525,10 +3525,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; IND-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; IND-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; IND:       vector.ph:
-; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -2
+; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 510
 ; IND-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; IND-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; IND-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; IND-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; IND-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT]], i64 0
 ; IND-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; IND-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 1>
@@ -3591,10 +3591,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; UNROLL-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; UNROLL-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; UNROLL:       vector.ph:
-; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -4
+; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 508
 ; UNROLL-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; UNROLL-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; UNROLL-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; UNROLL-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; UNROLL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT]], i64 0
 ; UNROLL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; UNROLL-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 1>
@@ -3735,10 +3735,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; INTERLEAVE-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; INTERLEAVE-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; INTERLEAVE:       vector.ph:
-; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -8
+; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 504
 ; INTERLEAVE-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; INTERLEAVE-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; INTERLEAVE-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; INTERLEAVE-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; INTERLEAVE-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EXT]], i64 0
 ; INTERLEAVE-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
 ; INTERLEAVE-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <4 x i32> [[DOTSPLAT]], <i32 0, i32 1, i32 2, i32 3>
@@ -3909,11 +3909,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; IND-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; IND-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; IND:       vector.ph:
-; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -2
+; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 510
 ; IND-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; IND-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; IND-NEXT:    [[EXT_MUL5:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; IND-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL5]], 2
+; IND-NEXT:    [[EXT_MUL5:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; IND-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL5]], 2
 ; IND-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; IND-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; IND-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 4>
@@ -3978,11 +3978,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; UNROLL-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; UNROLL-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; UNROLL:       vector.ph:
-; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -4
+; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 508
 ; UNROLL-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; UNROLL-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; UNROLL-NEXT:    [[EXT_MUL6:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; UNROLL-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL6]], 2
+; UNROLL-NEXT:    [[EXT_MUL6:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; UNROLL-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL6]], 2
 ; UNROLL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; UNROLL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; UNROLL-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 4>
@@ -4128,11 +4128,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; INTERLEAVE-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; INTERLEAVE-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; INTERLEAVE:       vector.ph:
-; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -8
+; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 504
 ; INTERLEAVE-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; INTERLEAVE-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; INTERLEAVE-NEXT:    [[EXT_MUL6:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; INTERLEAVE-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL6]], 2
+; INTERLEAVE-NEXT:    [[EXT_MUL6:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; INTERLEAVE-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL6]], 2
 ; INTERLEAVE-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; INTERLEAVE-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
 ; INTERLEAVE-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <4 x i32> [[DOTSPLAT]], <i32 0, i32 4, i32 8, i32 12>

``````````

</details>


https://github.com/llvm/llvm-project/pull/74728


More information about the llvm-commits mailing list