[llvm] [ValueTracking] Support dominating known bits condition in and/or (PR #74728)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 7 07:38:08 PST 2023


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/74728

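This extends computeKnownBits() support for dominating conditions to also handle and/or conditions. On the true edge of a branch we look through `and` conditions (if `a && b` is true, both `a` and `b` hold); on the false edge we look through `or` conditions (if `a || b` is false, both `a` and `b` are false).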

This change is mainly for the sake of completeness, so we don't start missing optimizations if SimplifyCFG decides to merge some branches.

From fe10f80641598818bbe88e519171c046d79573b6 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 7 Dec 2023 11:28:37 +0100
Subject: [PATCH] [ValueTracking] Support dominating known bits condition in
 and/or

This extends computeKnownBits() support for dominating conditions
to also handle and/or conditions. On the true edge we look through
`and` (both operands hold); on the false edge we look through `or`
(both operands are false).

This change is mainly for the sake of completeness, so we don't
start missing optimizations if SimplifyCFG decides to merge some
branches.
---
 llvm/lib/Analysis/DomConditionCache.cpp       | 50 ++++++++++++-------
 llvm/lib/Analysis/ValueTracking.cpp           | 32 ++++++++----
 .../test/Transforms/InstCombine/known-bits.ll | 15 ++----
 .../Transforms/LoopVectorize/induction.ll     | 30 +++++------
 4 files changed, 75 insertions(+), 52 deletions(-)

diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index 351881fe9e1f9..4a1eea57ed5bf 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -34,24 +34,40 @@ static void findAffectedValues(Value *Cond,
     }
   };
 
-  ICmpInst::Predicate Pred;
-  Value *A;
-  Constant *C;
-  if (match(Cond, m_ICmp(Pred, m_Value(A), m_Constant(C)))) {
-    AddAffected(A);
+  bool TopLevelIsAnd = match(Cond, m_LogicalAnd());
+  SmallVector<Value *, 8> Worklist;
+  SmallPtrSet<Value *, 8> Visited;
+  Worklist.push_back(Cond);
+  while (!Worklist.empty()) {
+    Value *V = Worklist.pop_back_val();
+    if (!Visited.insert(V).second)
+      continue;
 
-    if (ICmpInst::isEquality(Pred)) {
-      Value *X;
-      // (X & C) or (X | C) or (X ^ C).
-      // (X << C) or (X >>_s C) or (X >>_u C).
-      if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
-          match(A, m_Shift(m_Value(X), m_ConstantInt())))
-        AddAffected(X);
-    } else {
-      Value *X;
-      // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
-      if (match(A, m_Add(m_Value(X), m_ConstantInt())))
-        AddAffected(X);
+    ICmpInst::Predicate Pred;
+    Value *A, *B;
+    Constant *C;
+    // Only recurse into and/or if it matches the top-level and/or type.
+    if (TopLevelIsAnd ? match(V, m_LogicalAnd(m_Value(A), m_Value(B)))
+                      : match(V, m_LogicalOr(m_Value(A), m_Value(B)))) {
+      Worklist.push_back(A);
+      Worklist.push_back(B);
+    } else if (match(V, m_ICmp(Pred, m_Value(A), m_Constant(C)))) {
+      AddAffected(A);
+
+      if (ICmpInst::isEquality(Pred)) {
+        Value *X;
+        // (X & C) or (X | C) or (X ^ C).
+        // (X << C) or (X >>_s C) or (X >>_u C).
+        if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+            match(A, m_Shift(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      } else {
+        Value *X;
+        // Handle (A + C1) u< C2, which is the canonical form of
+        // A > C3 && A < C4.
+        if (match(A, m_Add(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      }
     }
   }
 }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2c52107dab23b..7db141a0ea758 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -705,28 +705,40 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
   }
 }
 
+static void computeKnownBitsFromCond(const Value *V, Value *Cond,
+                                     KnownBits &Known, unsigned Depth,
+                                     const SimplifyQuery &SQ, bool Invert) {
+  Value *A, *B;
+  if (Depth < MaxAnalysisRecursionDepth &&
+      (Invert ? match(Cond, m_LogicalOr(m_Value(A), m_Value(B)))
+              : match(Cond, m_LogicalAnd(m_Value(A), m_Value(B))))) {
+    computeKnownBitsFromCond(V, A, Known, Depth + 1, SQ, Invert);
+    computeKnownBitsFromCond(V, B, Known, Depth + 1, SQ, Invert);
+  }
+
+  if (auto *Cmp = dyn_cast<ICmpInst>(Cond))
+    computeKnownBitsFromCmp(
+        V, Invert ? Cmp->getInversePredicate() : Cmp->getPredicate(),
+        Cmp->getOperand(0), Cmp->getOperand(1), Known, Depth, SQ);
+}
+
 void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
-                                      unsigned Depth, const SimplifyQuery &Q) {
+                                       unsigned Depth, const SimplifyQuery &Q) {
   if (!Q.CxtI)
     return;
 
   if (Q.DC && Q.DT) {
     // Handle dominating conditions.
     for (BranchInst *BI : Q.DC->conditionsFor(V)) {
-      auto *Cmp = dyn_cast<ICmpInst>(BI->getCondition());
-      if (!Cmp)
-        continue;
-
       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
       if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
-        computeKnownBitsFromCmp(V, Cmp->getPredicate(), Cmp->getOperand(0),
-                                Cmp->getOperand(1), Known, Depth, Q);
+        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
+                                 /*Invert*/ false);
 
       BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
       if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
-        computeKnownBitsFromCmp(V, Cmp->getInversePredicate(),
-                                Cmp->getOperand(0), Cmp->getOperand(1), Known,
-                                Depth, Q);
+        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
+                                 /*Invert*/ true);
     }
 
     if (Known.hasConflict())
diff --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index e346330aa5b1e..246579cc4cd0c 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -105,8 +105,7 @@ define i8 @test_cond_and(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = and i1 [[CMP]], [[C:%.*]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -133,8 +132,7 @@ define i8 @test_cond_and_commuted(i8 %x, i1 %c1, i1 %c2) {
 ; CHECK-NEXT:    [[COND:%.*]] = and i1 [[C3]], [[CMP]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -161,8 +159,7 @@ define i8 @test_cond_logical_and(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i1 [[C:%.*]], i1 false
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -218,8 +215,7 @@ define i8 @test_cond_inv_or(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR1]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR2]]
+; CHECK-NEXT:    ret i8 -4
 ;
   %and = and i8 %x, 3
   %cmp = icmp ne i8 %and, 0
@@ -242,8 +238,7 @@ define i8 @test_cond_inv_logical_or(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP_NOT]], i1 [[C:%.*]], i1 false
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
diff --git a/llvm/test/Transforms/LoopVectorize/induction.ll b/llvm/test/Transforms/LoopVectorize/induction.ll
index a8cfac64258e8..9e7648b29cfa5 100644
--- a/llvm/test/Transforms/LoopVectorize/induction.ll
+++ b/llvm/test/Transforms/LoopVectorize/induction.ll
@@ -3525,10 +3525,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; IND-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; IND-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; IND:       vector.ph:
-; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -2
+; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 510
 ; IND-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; IND-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; IND-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; IND-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; IND-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT]], i64 0
 ; IND-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; IND-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 1>
@@ -3591,10 +3591,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; UNROLL-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; UNROLL-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; UNROLL:       vector.ph:
-; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -4
+; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 508
 ; UNROLL-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; UNROLL-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; UNROLL-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; UNROLL-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; UNROLL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT]], i64 0
 ; UNROLL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; UNROLL-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 1>
@@ -3735,10 +3735,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; INTERLEAVE-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; INTERLEAVE-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; INTERLEAVE:       vector.ph:
-; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -8
+; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 504
 ; INTERLEAVE-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; INTERLEAVE-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; INTERLEAVE-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; INTERLEAVE-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; INTERLEAVE-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EXT]], i64 0
 ; INTERLEAVE-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
 ; INTERLEAVE-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <4 x i32> [[DOTSPLAT]], <i32 0, i32 1, i32 2, i32 3>
@@ -3909,11 +3909,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; IND-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; IND-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; IND:       vector.ph:
-; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -2
+; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 510
 ; IND-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; IND-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; IND-NEXT:    [[EXT_MUL5:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; IND-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL5]], 2
+; IND-NEXT:    [[EXT_MUL5:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; IND-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL5]], 2
 ; IND-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; IND-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; IND-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 4>
@@ -3978,11 +3978,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; UNROLL-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; UNROLL-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; UNROLL:       vector.ph:
-; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -4
+; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 508
 ; UNROLL-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; UNROLL-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; UNROLL-NEXT:    [[EXT_MUL6:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; UNROLL-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL6]], 2
+; UNROLL-NEXT:    [[EXT_MUL6:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; UNROLL-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL6]], 2
 ; UNROLL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; UNROLL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; UNROLL-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 4>
@@ -4128,11 +4128,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; INTERLEAVE-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; INTERLEAVE-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; INTERLEAVE:       vector.ph:
-; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -8
+; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 504
 ; INTERLEAVE-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; INTERLEAVE-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; INTERLEAVE-NEXT:    [[EXT_MUL6:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; INTERLEAVE-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL6]], 2
+; INTERLEAVE-NEXT:    [[EXT_MUL6:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; INTERLEAVE-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL6]], 2
 ; INTERLEAVE-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; INTERLEAVE-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
 ; INTERLEAVE-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <4 x i32> [[DOTSPLAT]], <i32 0, i32 4, i32 8, i32 12>



More information about the llvm-commits mailing list