[llvm] [ValueTracking] Support dominating known bits condition in and/or (PR #74728)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 8 01:28:30 PST 2023


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/74728

>From 01448920bb40392cb0f09e4a736ab199a65cdebb Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 7 Dec 2023 11:28:37 +0100
Subject: [PATCH] [ValueTracking] Support dominating known bits condition in
 and/or

This extends computeKnownBits() support for dominating conditions
to also handle and/or conditions. We'll look through either a
(logical) `and` or a (logical) `or`, depending on which edge we're
considering: `and` on the true edge, `or` on the false edge.

This change is mainly for the sake of completeness, so we don't
start missing optimizations if SimplifyCFG decides to merge some
branches.
---
 llvm/lib/Analysis/DomConditionCache.cpp       | 48 ++++++++++++-------
 llvm/lib/Analysis/ValueTracking.cpp           | 32 +++++++++----
 .../test/Transforms/InstCombine/known-bits.ll | 15 ++----
 .../Transforms/LoopVectorize/induction.ll     | 30 ++++++------
 4 files changed, 74 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index c7f4cab415888..3dad0c2e07133 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -34,23 +34,39 @@ static void findAffectedValues(Value *Cond,
     }
   };
 
-  ICmpInst::Predicate Pred;
-  Value *A;
-  if (match(Cond, m_ICmp(Pred, m_Value(A), m_Constant()))) {
-    AddAffected(A);
+  bool TopLevelIsAnd = match(Cond, m_LogicalAnd());
+  SmallVector<Value *, 8> Worklist;
+  SmallPtrSet<Value *, 8> Visited;
+  Worklist.push_back(Cond);
+  while (!Worklist.empty()) {
+    Value *V = Worklist.pop_back_val();
+    if (!Visited.insert(V).second)
+      continue;
 
-    if (ICmpInst::isEquality(Pred)) {
-      Value *X;
-      // (X & C) or (X | C) or (X ^ C).
-      // (X << C) or (X >>_s C) or (X >>_u C).
-      if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
-          match(A, m_Shift(m_Value(X), m_ConstantInt())))
-        AddAffected(X);
-    } else {
-      Value *X;
-      // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4.
-      if (match(A, m_Add(m_Value(X), m_ConstantInt())))
-        AddAffected(X);
+    ICmpInst::Predicate Pred;
+    Value *A, *B;
+    // Only recurse into and/or if it matches the top-level and/or type.
+    if (TopLevelIsAnd ? match(V, m_LogicalAnd(m_Value(A), m_Value(B)))
+                      : match(V, m_LogicalOr(m_Value(A), m_Value(B)))) {
+      Worklist.push_back(A);
+      Worklist.push_back(B);
+    } else if (match(V, m_ICmp(Pred, m_Value(A), m_Constant()))) {
+      AddAffected(A);
+
+      if (ICmpInst::isEquality(Pred)) {
+        Value *X;
+        // (X & C) or (X | C) or (X ^ C).
+        // (X << C) or (X >>_s C) or (X >>_u C).
+        if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+            match(A, m_Shift(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      } else {
+        Value *X;
+        // Handle (A + C1) u< C2, which is the canonical form of
+        // A > C3 && A < C4.
+        if (match(A, m_Add(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      }
     }
   }
 }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index eb2ee045146cf..2eb3b17ccffb1 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -705,28 +705,40 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
   }
 }
 
+static void computeKnownBitsFromCond(const Value *V, Value *Cond,
+                                     KnownBits &Known, unsigned Depth,
+                                     const SimplifyQuery &SQ, bool Invert) {
+  Value *A, *B;
+  if (Depth < MaxAnalysisRecursionDepth &&
+      (Invert ? match(Cond, m_LogicalOr(m_Value(A), m_Value(B)))
+              : match(Cond, m_LogicalAnd(m_Value(A), m_Value(B))))) {
+    computeKnownBitsFromCond(V, A, Known, Depth + 1, SQ, Invert);
+    computeKnownBitsFromCond(V, B, Known, Depth + 1, SQ, Invert);
+  }
+
+  if (auto *Cmp = dyn_cast<ICmpInst>(Cond))
+    computeKnownBitsFromCmp(
+        V, Invert ? Cmp->getInversePredicate() : Cmp->getPredicate(),
+        Cmp->getOperand(0), Cmp->getOperand(1), Known, SQ);
+}
+
 void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
-                                      unsigned Depth, const SimplifyQuery &Q) {
+                                       unsigned Depth, const SimplifyQuery &Q) {
   if (!Q.CxtI)
     return;
 
   if (Q.DC && Q.DT) {
     // Handle dominating conditions.
     for (BranchInst *BI : Q.DC->conditionsFor(V)) {
-      auto *Cmp = dyn_cast<ICmpInst>(BI->getCondition());
-      if (!Cmp)
-        continue;
-
       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
       if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
-        computeKnownBitsFromCmp(V, Cmp->getPredicate(), Cmp->getOperand(0),
-                                Cmp->getOperand(1), Known, Q);
+        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
+                                 /*Invert*/ false);
 
       BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
       if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
-        computeKnownBitsFromCmp(V, Cmp->getInversePredicate(),
-                                Cmp->getOperand(0), Cmp->getOperand(1), Known,
-                                Q);
+        computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
+                                 /*Invert*/ true);
     }
 
     if (Known.hasConflict())
diff --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index e346330aa5b1e..246579cc4cd0c 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -105,8 +105,7 @@ define i8 @test_cond_and(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = and i1 [[CMP]], [[C:%.*]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -133,8 +132,7 @@ define i8 @test_cond_and_commuted(i8 %x, i1 %c1, i1 %c2) {
 ; CHECK-NEXT:    [[COND:%.*]] = and i1 [[C3]], [[CMP]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -161,8 +159,7 @@ define i8 @test_cond_logical_and(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i1 [[C:%.*]], i1 false
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
@@ -218,8 +215,7 @@ define i8 @test_cond_inv_or(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR1]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR2]]
+; CHECK-NEXT:    ret i8 -4
 ;
   %and = and i8 %x, 3
   %cmp = icmp ne i8 %and, 0
@@ -242,8 +238,7 @@ define i8 @test_cond_inv_logical_or(i8 %x, i1 %c) {
 ; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP_NOT]], i1 [[C:%.*]], i1 false
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF:%.*]], label [[EXIT:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OR1:%.*]] = or i8 [[X]], -4
-; CHECK-NEXT:    ret i8 [[OR1]]
+; CHECK-NEXT:    ret i8 -4
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[OR2:%.*]] = or i8 [[X]], -4
 ; CHECK-NEXT:    ret i8 [[OR2]]
diff --git a/llvm/test/Transforms/LoopVectorize/induction.ll b/llvm/test/Transforms/LoopVectorize/induction.ll
index a8cfac64258e8..9e7648b29cfa5 100644
--- a/llvm/test/Transforms/LoopVectorize/induction.ll
+++ b/llvm/test/Transforms/LoopVectorize/induction.ll
@@ -3525,10 +3525,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; IND-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; IND-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; IND:       vector.ph:
-; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -2
+; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 510
 ; IND-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; IND-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; IND-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; IND-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; IND-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT]], i64 0
 ; IND-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; IND-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 1>
@@ -3591,10 +3591,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; UNROLL-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; UNROLL-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; UNROLL:       vector.ph:
-; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -4
+; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 508
 ; UNROLL-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; UNROLL-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; UNROLL-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; UNROLL-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; UNROLL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT]], i64 0
 ; UNROLL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; UNROLL-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 1>
@@ -3735,10 +3735,10 @@ define void @wrappingindvars1(i8 %t, i32 %len, ptr %A) {
 ; INTERLEAVE-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; INTERLEAVE-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; INTERLEAVE:       vector.ph:
-; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -8
+; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 504
 ; INTERLEAVE-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; INTERLEAVE-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; INTERLEAVE-NEXT:    [[IND_END2:%.*]] = add i32 [[N_VEC]], [[EXT]]
+; INTERLEAVE-NEXT:    [[IND_END2:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
 ; INTERLEAVE-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EXT]], i64 0
 ; INTERLEAVE-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
 ; INTERLEAVE-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <4 x i32> [[DOTSPLAT]], <i32 0, i32 1, i32 2, i32 3>
@@ -3909,11 +3909,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; IND-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; IND-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; IND:       vector.ph:
-; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -2
+; IND-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 510
 ; IND-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; IND-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; IND-NEXT:    [[EXT_MUL5:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; IND-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL5]], 2
+; IND-NEXT:    [[EXT_MUL5:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; IND-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL5]], 2
 ; IND-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; IND-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; IND-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 4>
@@ -3978,11 +3978,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; UNROLL-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; UNROLL-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; UNROLL:       vector.ph:
-; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -4
+; UNROLL-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 508
 ; UNROLL-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; UNROLL-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; UNROLL-NEXT:    [[EXT_MUL6:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; UNROLL-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL6]], 2
+; UNROLL-NEXT:    [[EXT_MUL6:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; UNROLL-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL6]], 2
 ; UNROLL-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <2 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; UNROLL-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <2 x i32> [[DOTSPLATINSERT]], <2 x i32> poison, <2 x i32> zeroinitializer
 ; UNROLL-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <2 x i32> [[DOTSPLAT]], <i32 0, i32 4>
@@ -4128,11 +4128,11 @@ define void @wrappingindvars2(i8 %t, i32 %len, ptr %A) {
 ; INTERLEAVE-NEXT:    [[TMP9:%.*]] = or i1 [[TMP3]], [[TMP8]]
 ; INTERLEAVE-NEXT:    br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; INTERLEAVE:       vector.ph:
-; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], -8
+; INTERLEAVE-NEXT:    [[N_VEC:%.*]] = and i32 [[TMP0]], 504
 ; INTERLEAVE-NEXT:    [[DOTCAST:%.*]] = trunc i32 [[N_VEC]] to i8
 ; INTERLEAVE-NEXT:    [[IND_END:%.*]] = add i8 [[DOTCAST]], [[T]]
-; INTERLEAVE-NEXT:    [[EXT_MUL6:%.*]] = add i32 [[N_VEC]], [[EXT]]
-; INTERLEAVE-NEXT:    [[IND_END1:%.*]] = shl i32 [[EXT_MUL6]], 2
+; INTERLEAVE-NEXT:    [[EXT_MUL6:%.*]] = add nuw nsw i32 [[N_VEC]], [[EXT]]
+; INTERLEAVE-NEXT:    [[IND_END1:%.*]] = shl nuw nsw i32 [[EXT_MUL6]], 2
 ; INTERLEAVE-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EXT_MUL]], i64 0
 ; INTERLEAVE-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
 ; INTERLEAVE-NEXT:    [[INDUCTION:%.*]] = add nuw nsw <4 x i32> [[DOTSPLAT]], <i32 0, i32 4, i32 8, i32 12>



More information about the llvm-commits mailing list