[llvm] [Analysis] Unify most of the tracking between AssumptionCache and DomConditionCache (PR #83161)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 27 10:01:08 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

<details>
<summary>Changes</summary>

- **[Analysis] Move `DomConditionCache::findAffectedValues` to a new file; NFC**
- **[Analysis] Share `findAffectedValues` between DomConditionCache and AssumptionCache; NFC**
- **[Analysis] Unify most of the tracking between AssumptionCache and DomConditionCache**


---
Full diff: https://github.com/llvm/llvm-project/pull/83161.diff


7 Files Affected:

- (added) llvm/include/llvm/Analysis/ConditionCacheUtil.h (+118) 
- (modified) llvm/lib/Analysis/AssumptionCache.cpp (+13-62) 
- (modified) llvm/lib/Analysis/DomConditionCache.cpp (+3-64) 
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+2-4) 
- (modified) llvm/test/Analysis/ValueTracking/numsignbits-from-assume.ll (+1-1) 
- (modified) llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll (+2-3) 
- (modified) llvm/test/Transforms/InstSimplify/assume_icmp.ll (+4-8) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/ConditionCacheUtil.h b/llvm/include/llvm/Analysis/ConditionCacheUtil.h
new file mode 100644
index 00000000000000..9078ac921eaca6
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ConditionCacheUtil.h
@@ -0,0 +1,118 @@
+//===- llvm/Analysis/ConditionCacheUtil.h -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Shared by DomConditionCache and AssumptionCache. Holds the common logic for
+// finding values potentially affected by an assumed or branched-on condition.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_CONDITIONCACHEUTIL_H
+#define LLVM_ANALYSIS_CONDITIONCACHEUTIL_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/PatternMatch.h"
+#include <functional>
+
+namespace llvm {
+
+/// Report \p V to \p InsertAffected (tagged with \p Idx) if it is an
+/// argument, global value, or instruction; other values (e.g. ConstantInts)
+/// are ignored. For `ptrtoint Op`, also report an instruction/argument Op.
+inline void addValueAffectedByCondition(
+    Value *V, const std::function<void(Value *, int)> &InsertAffected,
+    int Idx = -1) {
+  using namespace llvm::PatternMatch;
+  assert(V != nullptr);
+  if (!isa<Argument>(V) && !isa<GlobalValue>(V) && !isa<Instruction>(V))
+    return;
+  InsertAffected(V, Idx);
+  // Peek through ptrtoint to find the source of the condition.
+  Value *Op;
+  if (match(V, m_PtrToInt(m_Value(Op))) &&
+      (isa<Instruction>(Op) || isa<Argument>(Op)))
+    InsertAffected(Op, Idx);
+}
+
+/// Walk \p Cond, looking through logical and/or, and report each value
+/// potentially affected by the condition to \p InsertAffected (index -1).
+/// With \p IsAssume, also report every visited condition, the operand of a
+/// `not`, and the RHS of compares.
+inline void findValuesAffectedByCondition(
+    Value *Cond, bool IsAssume,
+    const std::function<void(Value *, int)> &InsertAffected) {
+  using namespace llvm::PatternMatch;
+  auto AddAffected = [&InsertAffected](Value *V) {
+    addValueAffectedByCondition(V, InsertAffected);
+  };
+
+  SmallVector<Value *, 8> Worklist;
+  SmallPtrSet<Value *, 8> Visited;
+  Worklist.push_back(Cond);
+  while (!Worklist.empty()) {
+    Value *V = Worklist.pop_back_val();
+    if (!Visited.insert(V).second)
+      continue;
+
+    CmpInst::Predicate Pred;
+    Value *A, *B, *X;
+    if (IsAssume) {
+      AddAffected(V);
+      if (match(V, m_Not(m_Value(X))))
+        AddAffected(X);
+    }
+
+    if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
+      Worklist.push_back(A);
+      Worklist.push_back(B);
+    } else if (match(V, m_Cmp(Pred, m_Value(A), m_Value(B)))) {
+      AddAffected(A);
+      if (IsAssume)
+        AddAffected(B);
+
+      if (ICmpInst::isEquality(Pred)) {
+        // (X & C) or (X | C) or (X ^ C) or (X << C) or (X >>_s C) or
+        // (X >>_u C), compared against a constant.
+        if (match(B, m_ConstantInt()) &&
+            (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+             match(A, m_Shift(m_Value(X), m_ConstantInt()))))
+          AddAffected(X);
+      } else {
+        // Handle (A + C1) u< C2, which is the canonical form of
+        // A > C3 && A < C4.
+        if (match(A, m_Add(m_Value(X), m_ConstantInt())) &&
+            match(B, m_ConstantInt()))
+          AddAffected(X);
+
+        // Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported
+        // by computeKnownFPClass().
+        if (match(A, m_ElementWiseBitCast(m_Value(X))) &&
+            ((Pred == ICmpInst::ICMP_SLT && match(B, m_Zero())) ||
+             (Pred == ICmpInst::ICMP_SGT && match(B, m_AllOnes()))))
+          InsertAffected(X, -1);
+
+        if (CmpInst::isFPPredicate(Pred)) {
+          // fcmp fneg(x), y
+          // fcmp fabs(x), y
+          // fcmp fneg(fabs(x)), y
+          if (match(A, m_FNeg(m_Value(A))))
+            AddAffected(A);
+          if (match(A, m_FAbs(m_Value(A))))
+            AddAffected(A);
+        }
+      }
+    } else if (match(V, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
+                                                           m_Value()))) {
+      // Match V, not the root Cond, so is.fpclass under and/or is seen.
+      AddAffected(A);
+    }
+  }
+}
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 1b7277df0e0cd0..4c0e0db7f9d19e 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/AssumeBundleQueries.h"
+#include "llvm/Analysis/ConditionCacheUtil.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/BasicBlock.h"
@@ -61,20 +62,8 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
   // Note: This code must be kept in-sync with the code in
   // computeKnownBitsFromAssume in ValueTracking.
 
-  auto AddAffected = [&Affected](Value *V, unsigned Idx =
-                                               AssumptionCache::ExprResultIdx) {
-    if (isa<Argument>(V) || isa<GlobalValue>(V)) {
-      Affected.push_back({V, Idx});
-    } else if (auto *I = dyn_cast<Instruction>(V)) {
-      Affected.push_back({I, Idx});
-
-      // Peek through unary operators to find the source of the condition.
-      Value *Op;
-      if (match(I, m_PtrToInt(m_Value(Op)))) {
-        if (isa<Instruction>(Op) || isa<Argument>(Op))
-          Affected.push_back({Op, Idx});
-      }
-    }
+  auto InsertAffected = [&Affected](Value *V, int Idx) {
+    Affected.push_back({V, Idx < 0 ? AssumptionCache::ExprResultIdx : Idx});
   };
 
   for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++) {
@@ -82,64 +71,26 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
     if (Bundle.getTagName() == "separate_storage") {
       assert(Bundle.Inputs.size() == 2 &&
              "separate_storage must have two args");
-      AddAffected(getUnderlyingObject(Bundle.Inputs[0]), Idx);
-      AddAffected(getUnderlyingObject(Bundle.Inputs[1]), Idx);
+      addValueAffectedByCondition(getUnderlyingObject(Bundle.Inputs[0]),
+                                  InsertAffected, Idx);
+      addValueAffectedByCondition(getUnderlyingObject(Bundle.Inputs[1]),
+                                  InsertAffected, Idx);
     } else if (Bundle.Inputs.size() > ABA_WasOn &&
                Bundle.getTagName() != IgnoreBundleTag)
-      AddAffected(Bundle.Inputs[ABA_WasOn], Idx);
+      addValueAffectedByCondition(Bundle.Inputs[ABA_WasOn], InsertAffected,
+                                  Idx);
   }
 
-  Value *Cond = CI->getArgOperand(0), *A, *B;
-  AddAffected(Cond);
-  if (match(Cond, m_Not(m_Value(A))))
-    AddAffected(A);
-
-  CmpInst::Predicate Pred;
-  if (match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B)))) {
-    AddAffected(A);
-    AddAffected(B);
-
-    if (Pred == ICmpInst::ICMP_EQ) {
-      if (match(B, m_ConstantInt())) {
-        Value *X;
-        // (X & C) or (X | C) or (X ^ C).
-        // (X << C) or (X >>_s C) or (X >>_u C).
-        if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
-            match(A, m_Shift(m_Value(X), m_ConstantInt())))
-          AddAffected(X);
-      }
-    } else if (Pred == ICmpInst::ICMP_NE) {
-      Value *X;
-      // Handle (X & pow2 != 0).
-      if (match(A, m_And(m_Value(X), m_Power2())) && match(B, m_Zero()))
-        AddAffected(X);
-    } else if (Pred == ICmpInst::ICMP_ULT) {
-      Value *X;
-      // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
-      // and recognized by LVI at least.
-      if (match(A, m_Add(m_Value(X), m_ConstantInt())) &&
-          match(B, m_ConstantInt()))
-        AddAffected(X);
-    } else if (CmpInst::isFPPredicate(Pred)) {
-      // fcmp fneg(x), y
-      // fcmp fabs(x), y
-      // fcmp fneg(fabs(x)), y
-      if (match(A, m_FNeg(m_Value(A))))
-        AddAffected(A);
-      if (match(A, m_FAbs(m_Value(A))))
-        AddAffected(A);
-    }
-  } else if (match(Cond, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
-                                                            m_Value(B)))) {
-    AddAffected(A);
-  }
+  Value *Cond = CI->getArgOperand(0);
+  findValuesAffectedByCondition(Cond, /*IsAssume*/ true, InsertAffected);
 
   if (TTI) {
     const Value *Ptr;
     unsigned AS;
     std::tie(Ptr, AS) = TTI->getPredicatedAddrSpace(Cond);
     if (Ptr)
-      AddAffected(const_cast<Value *>(Ptr->stripInBoundsOffsets()));
+      addValueAffectedByCondition(
+          const_cast<Value *>(Ptr->stripInBoundsOffsets()), InsertAffected);
   }
 }
 
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index da05e02b4b57f7..fadd0a6f22953b 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -7,75 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/DomConditionCache.h"
-#include "llvm/IR/PatternMatch.h"
+#include "llvm/Analysis/ConditionCacheUtil.h"
 
 using namespace llvm;
-using namespace llvm::PatternMatch;
 
-// TODO: This code is very similar to findAffectedValues() in
-// AssumptionCache, but currently specialized to just the patterns that
-// computeKnownBits() supports, and without the notion of result elem indices
-// that are AC specific. Deduplicate this code once we have a clearer picture
-// of how much they can be shared.
 static void findAffectedValues(Value *Cond,
                                SmallVectorImpl<Value *> &Affected) {
-  auto AddAffected = [&Affected](Value *V) {
-    if (isa<Argument>(V) || isa<GlobalValue>(V)) {
-      Affected.push_back(V);
-    } else if (auto *I = dyn_cast<Instruction>(V)) {
-      Affected.push_back(I);
-
-      // Peek through unary operators to find the source of the condition.
-      Value *Op;
-      if (match(I, m_PtrToInt(m_Value(Op)))) {
-        if (isa<Instruction>(Op) || isa<Argument>(Op))
-          Affected.push_back(Op);
-      }
-    }
-  };
-
-  SmallVector<Value *, 8> Worklist;
-  SmallPtrSet<Value *, 8> Visited;
-  Worklist.push_back(Cond);
-  while (!Worklist.empty()) {
-    Value *V = Worklist.pop_back_val();
-    if (!Visited.insert(V).second)
-      continue;
-
-    CmpInst::Predicate Pred;
-    Value *A, *B;
-    if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
-      Worklist.push_back(A);
-      Worklist.push_back(B);
-    } else if (match(V, m_ICmp(Pred, m_Value(A), m_Constant()))) {
-      AddAffected(A);
-
-      if (ICmpInst::isEquality(Pred)) {
-        Value *X;
-        // (X & C) or (X | C) or (X ^ C).
-        // (X << C) or (X >>_s C) or (X >>_u C).
-        if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
-            match(A, m_Shift(m_Value(X), m_ConstantInt())))
-          AddAffected(X);
-      } else {
-        Value *X;
-        // Handle (A + C1) u< C2, which is the canonical form of
-        // A > C3 && A < C4.
-        if (match(A, m_Add(m_Value(X), m_ConstantInt())))
-          AddAffected(X);
-        // Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported by
-        // computeKnownFPClass().
-        if ((Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT) &&
-            match(A, m_ElementWiseBitCast(m_Value(X))))
-          Affected.push_back(X);
-      }
-    } else if (match(Cond, m_CombineOr(m_FCmp(Pred, m_Value(A), m_Constant()),
-                                       m_Intrinsic<Intrinsic::is_fpclass>(
-                                           m_Value(A), m_Constant())))) {
-      // Handle patterns that computeKnownFPClass() support.
-      AddAffected(A);
-    }
-  }
+  auto InsertAffected = [&Affected](Value *V, int) { Affected.push_back(V); };
+  findValuesAffectedByCondition(Cond, /*IsAssume*/ false, InsertAffected);
 }
 
 void DomConditionCache::registerBranch(BranchInst *BI) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e591ac504e9f05..39d358b45d9021 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -806,15 +806,13 @@ void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
     if (Depth == MaxAnalysisRecursionDepth)
       continue;
 
-    ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
-    if (!Cmp)
+    if (!isa<ICmpInst>(Arg) && !match(Arg, m_LogicalOp(m_Value(), m_Value())))
       continue;
 
     if (!isValidAssumeForContext(I, Q.CxtI, Q.DT))
       continue;
 
-    computeKnownBitsFromCmp(V, Cmp->getPredicate(), Cmp->getOperand(0),
-                            Cmp->getOperand(1), Known, Q);
+    computeKnownBitsFromCond(V, Arg, Known, /*Depth*/ 0, Q, /*Invert*/ false);
   }
 
   // Conflicting assumption: Undefined behavior will occur on this execution
diff --git a/llvm/test/Analysis/ValueTracking/numsignbits-from-assume.ll b/llvm/test/Analysis/ValueTracking/numsignbits-from-assume.ll
index 00c66eeb599572..95ac98532da621 100644
--- a/llvm/test/Analysis/ValueTracking/numsignbits-from-assume.ll
+++ b/llvm/test/Analysis/ValueTracking/numsignbits-from-assume.ll
@@ -51,7 +51,7 @@ define i32 @computeNumSignBits_sub1(i32 %in) {
 
 define i32 @computeNumSignBits_sub2(i32 %in) {
 ; CHECK-LABEL: @computeNumSignBits_sub2(
-; CHECK-NEXT:    [[SUB:%.*]] = add i32 [[IN:%.*]], -1
+; CHECK-NEXT:    [[SUB:%.*]] = add nsw i32 [[IN:%.*]], -1
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[SUB]], 43
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
 ; CHECK-NEXT:    [[SH:%.*]] = shl nuw nsw i32 [[SUB]], 3
diff --git a/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
index d40cd7fd503ecc..d6706d76056eea 100644
--- a/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
+++ b/llvm/test/Transforms/InstCombine/fpclass-from-dom-cond.ll
@@ -185,10 +185,9 @@ define i1 @test8(float %x) {
 ; CHECK-NEXT:    [[COND:%.*]] = fcmp oeq float [[ABS]], 0x7FF0000000000000
 ; CHECK-NEXT:    br i1 [[COND]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
 ; CHECK:       if.then:
-; CHECK-NEXT:    [[RET1:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X]], i32 575)
-; CHECK-NEXT:    ret i1 [[RET1]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       if.else:
-; CHECK-NEXT:    [[RET2:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X]], i32 575)
+; CHECK-NEXT:    [[RET2:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X]], i32 59)
 ; CHECK-NEXT:    ret i1 [[RET2]]
 ;
   %abs = call float @llvm.fabs.f32(float %x)
diff --git a/llvm/test/Transforms/InstSimplify/assume_icmp.ll b/llvm/test/Transforms/InstSimplify/assume_icmp.ll
index 9ac3f468ab604f..2ebb3b23a9b883 100644
--- a/llvm/test/Transforms/InstSimplify/assume_icmp.ll
+++ b/llvm/test/Transforms/InstSimplify/assume_icmp.ll
@@ -93,14 +93,10 @@ define void @and(i32 %x, i32 %y, i32 %z) {
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[Z:%.*]], [[Y]]
 ; CHECK-NEXT:    [[AND:%.*]] = and i1 [[CMP1]], [[CMP2]]
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[AND]])
-; CHECK-NEXT:    [[CMP3:%.*]] = icmp ugt i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP3]])
-; CHECK-NEXT:    [[CMP4:%.*]] = icmp uge i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP4]])
-; CHECK-NEXT:    [[CMP5:%.*]] = icmp ugt i32 [[Z]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP5]])
-; CHECK-NEXT:    [[CMP6:%.*]] = icmp uge i32 [[Z]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP6]])
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    ret void
 ;
   %cmp1 = icmp ugt i32 %x, %y

``````````

</details>


https://github.com/llvm/llvm-project/pull/83161


More information about the llvm-commits mailing list