[llvm] [ValueTracking] Support dominating known bits condition in and/or (PR #74728)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 09:02:23 PST 2024


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/74728

>From bae764a404d1fed0a844dcbd74d0ffebd74735ff Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 7 Dec 2023 11:28:37 +0100
Subject: [PATCH] [ValueTracking] Support dominating known bits condition in
 and/or

This extends computeKnownBits() support for dominating conditions
to also handle and/or conditions. We look through either `and` or
`or`, depending on which edge we're considering.

This change is mainly for the sake of completeness, so we don't
start missing optimizations if SimplifyCFG decides to merge some
branches.
---
 llvm/lib/Analysis/DomConditionCache.cpp       | 48 ++++++++++++-------
 llvm/lib/Analysis/ValueTracking.cpp           | 32 +++++++++----
 .../test/Transforms/InstCombine/known-bits.ll | 15 ++----
 .../Transforms/LoopVectorize/induction.ll     | 30 ++++++------
 4 files changed, 74 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index c7f4cab415888..3dad0c2e07133 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -34,23 +34,39 @@ static void findAffectedValues(Value *Cond,
     }
   };
 
-  ICmpInst::Predicate Pred;
-  Value *A;
-  if (match(Cond, m_ICmp(Pred, m_Value(A), m_Constant()))) {
-    AddAffected(A);
+  bool TopLevelIsAnd = match(Cond, m_LogicalAnd());
+  SmallVector<Value *, 8> Worklist;
+  SmallPtrSet<Value *, 8> Visited;
+  Worklist.push_back(Cond);
+  while (!Worklist.empty()) {
+    Value *V = Worklist.pop_back_val();
+    if (!Visited.insert(V).second)
+      continue;
 
-    if (ICmpInst::isEquality(Pred)) {
-      Value *X;
-      // (X & C) or (X | C) or (X ^ C).
-      // (X << C) or (X >>_s C) or (X >>_u C).
-      if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
-          match(A, m_Shift(m_Value(X), m_ConstantInt())))
-        AddAffected(X);
-    } else {
-      Value *X;
-      // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
-      if (match(A, m_Add(m_Value(X), m_ConstantInt())))
-        AddAffected(X);
+    ICmpInst::Predicate Pred;
+    Value *A, *B;
+    // Only recurse into and/or if it matches the top-level and/or type.
+    if (TopLevelIsAnd ? match(V, m_LogicalAnd(m_Value(A), m_Value(B)))
+                      : match(V, m_LogicalOr(m_Value(A), m_Value(B)))) {
+      Worklist.push_back(A);
+      Worklist.push_back(B);
+    } else if (match(V, m_ICmp(Pred, m_Value(A), m_Constant()))) {
+      AddAffected(A);
+
+      if (ICmpInst::isEquality(Pred)) {
+        Value *X;
+        // (X & C) or (X | C) or (X ^ C).
+        // (X << C) or (X >>_s C) or (X >>_u C).
+        if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+            match(A, m_Shift(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      } else {
+        Value *X;
+        // Handle (A + C1) u< C2, which is the canonical form of
+        // A > C3 && A < C4.
+        if (match(A, m_Add(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      }
     }
   }
 }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 58db81f470130..0e40a02fd4de6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -706,28 +706,40 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
   }
 }
 
+static void computeKnownBitsFromCond(const Value *V, Value *Cond,
+                                     KnownBits &Known, unsigned Depth,
+                                     const SimplifyQuery &SQ, bool Invert) {
+  Value *A, *B;
+  if (Depth < MaxAnalysisRecursionDepth &&
+      (Invert ? match(Cond, m_LogicalOr(m_Value(A), m_Value(B)))
+              : match(Cond, m_LogicalAnd(m_Value(A), m_Value(B))))) {
+    computeKnownBitsFromCond(V, A, Known, Depth + 1, SQ, Invert);
+    computeKnownBitsFromCond(V, B, Known, Depth + 1, SQ, Invert);
+  }
+
+  if (auto *Cmp = dyn_cast<ICmpInst>(Cond))
+    computeKnownBitsFromCmp(
+        V, Invert ? Cmp->getInversePredicate() : Cmp->getPredicate(),
+        Cmp->getOperand(0), Cmp->getOperand(1), Known, SQ);
+}
+
 void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
-                                      unsigned Depth, const SimplifyQuery &Q) {
+                                       unsigned Depth, const SimplifyQuery &Q) {
   if (!Q.CxtI)
     return;
 
   if (Q.DC && Q.DT) {
     // Handle dominating conditions.
     for (BranchInst *BI : Q.DC->conditionsFor(V)) {
-      auto *Cmp = dyn_cast<ICmpInst>(BI->getCondition());
-      if (!Cmp)
-        continue;
-
       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
       if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
-        computeKnownBitsFromCmp(V, Cmp->getPredicate(), Cmp->getOperand(0),
-                                Cmp->getOperand(1), Known, Q);
+        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
+                                 /*Invert*/ false);
 
       BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
       if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
-        computeKnownBitsFromCmp(V, Cmp->getInversePredicate(),
-                                Cmp->getOperand(0), Cmp->getOperand(1), Known,
-                                Q);
+        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
+                                 /*Invert*/ true);
     }
 
     if (Known.hasConflict())
diff --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index e346330aa5b1e..246579cc4cd0c 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -105,8 +105,7 @@ define i8 @test_cond_and(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = and i1 [[CMP]], [[C:%.*]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -133,8 +132,7 @@ define i8 @test_cond_and_commuted(i8 %x, i1 %c1, i1 %c2) {
 ; CHECK-NEXT:    [[COND:%.*]] = and i1 [[C3]], [[CMP]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -161,8 +159,7 @@ define i8 @test_cond_logical_and(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i1 [[C:%.*]], i1 false
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -218,8 +215,7 @@ define i8 @test_cond_inv_or(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR1]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR2]]
+; CHECK-NEXT:    ret i8 -4
 ;
   %and = and i8 %x, 3
   %cmp = icmp ne i8 %and, 0
@@ -242,8 +238,7 @@ define i8 @test_cond_inv_logical_or(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP_NOT]], i1 [[C:%.*]], i1 false
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
diff --git a/llvm/test/Transforms/LoopVectorize/induction.ll b/llvm/test/Transforms/LoopVectorize/induction.ll
index 29d8719db9b29..50a5cc6774c5c 100644
--- a/llvm/test/Transforms/LoopVectorize/induction.ll
+++ b/llvm/test/Transforms/LoopVectorize/induction.ll
@@ -3523,10 +3523,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; IND-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; IND-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; IND:       vector.ph:
-; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -2
+; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 510
 ; IND-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; IND-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; IND-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; IND-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; IND-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT]], i64 0
 ; IND-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; IND-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 1>
@@ -3589,10 +3589,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; UNROLL-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; UNROLL-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; UNROLL:       vector.ph:
-; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -4
+; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 508
 ; UNROLL-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; UNROLL-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; UNROLL-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; UNROLL-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; UNROLL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT]], i64 0
 ; UNROLL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; UNROLL-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 1>
@@ -3733,10 +3733,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; INTERLEAVE-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; INTERLEAVE-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; INTERLEAVE:       vector.ph:
-; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -8
+; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 504
 ; INTERLEAVE-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; INTERLEAVE-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; INTERLEAVE-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; INTERLEAVE-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; INTERLEAVE-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EXT]], i64 0
 ; INTERLEAVE-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
 ; INTERLEAVE-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <4 x i32> [[DOTSPLAT]], <i32 0, i32 1, i32 2, i32 3>
@@ -3907,11 +3907,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; IND-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; IND-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; IND:       vector.ph:
-; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -2
+; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 510
 ; IND-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; IND-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; IND-NEXT:    [[EXT_MUL5:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; IND-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL5]], 2
+; IND-NEXT:    [[EXT_MUL5:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; IND-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL5]], 2
 ; IND-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; IND-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; IND-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 4>
@@ -3976,11 +3976,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; UNROLL-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; UNROLL-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; UNROLL:       vector.ph:
-; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -4
+; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 508
 ; UNROLL-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; UNROLL-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; UNROLL-NEXT:    [[EXT_MUL6:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; UNROLL-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL6]], 2
+; UNROLL-NEXT:    [[EXT_MUL6:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; UNROLL-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL6]], 2
 ; UNROLL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; UNROLL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; UNROLL-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 4>
@@ -4126,11 +4126,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; INTERLEAVE-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; INTERLEAVE-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; INTERLEAVE:       vector.ph:
-; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -8
+; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 504
 ; INTERLEAVE-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; INTERLEAVE-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; INTERLEAVE-NEXT:    [[EXT_MUL6:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; INTERLEAVE-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL6]], 2
+; INTERLEAVE-NEXT:    [[EXT_MUL6:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; INTERLEAVE-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL6]], 2
 ; INTERLEAVE-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; INTERLEAVE-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
 ; INTERLEAVE-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <4 x i32> [[DOTSPLAT]], <i32 0, i32 4, i32 8, i32 12>



More information about the llvm-commits mailing list