[llvm] [ValueTracking] Use SimplifyQuery in some public APIs (NFC) (PR #68290)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 10 02:03:13 PDT 2023


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/68290

>From 500a6c95ff63d0b1d68afe7b64fad4a569748aea Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 5 Oct 2023 10:00:48 +0200
Subject: [PATCH 1/2] [Analysis] Move SimplifyQuery into separate header (NFC)

To allow reusing it between InstructionSimplify and ValueTracking.
---
 .../llvm/Analysis/InstructionSimplify.h       |  98 +--------------
 llvm/include/llvm/Analysis/SimplifyQuery.h    | 118 ++++++++++++++++++
 2 files changed, 119 insertions(+), 97 deletions(-)
 create mode 100644 llvm/include/llvm/Analysis/SimplifyQuery.h

diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 92005da1f4c61e7..c626a6522d01779 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -31,7 +31,7 @@
 #ifndef LLVM_ANALYSIS_INSTRUCTIONSIMPLIFY_H
 #define LLVM_ANALYSIS_INSTRUCTIONSIMPLIFY_H
 
-#include "llvm/IR/PatternMatch.h"
+#include "llvm/Analysis/SimplifyQuery.h"
 
 namespace llvm {
 
@@ -52,102 +52,6 @@ class TargetLibraryInfo;
 class Type;
 class Value;
 
-/// InstrInfoQuery provides an interface to query additional information for
-/// instructions like metadata or keywords like nsw, which provides conservative
-/// results if the users specified it is safe to use.
-struct InstrInfoQuery {
-  InstrInfoQuery(bool UMD) : UseInstrInfo(UMD) {}
-  InstrInfoQuery() = default;
-  bool UseInstrInfo = true;
-
-  MDNode *getMetadata(const Instruction *I, unsigned KindID) const {
-    if (UseInstrInfo)
-      return I->getMetadata(KindID);
-    return nullptr;
-  }
-
-  template <class InstT> bool hasNoUnsignedWrap(const InstT *Op) const {
-    if (UseInstrInfo)
-      return Op->hasNoUnsignedWrap();
-    return false;
-  }
-
-  template <class InstT> bool hasNoSignedWrap(const InstT *Op) const {
-    if (UseInstrInfo)
-      return Op->hasNoSignedWrap();
-    return false;
-  }
-
-  bool isExact(const BinaryOperator *Op) const {
-    if (UseInstrInfo && isa<PossiblyExactOperator>(Op))
-      return cast<PossiblyExactOperator>(Op)->isExact();
-    return false;
-  }
-
-  template <class InstT> bool hasNoSignedZeros(const InstT *Op) const {
-    if (UseInstrInfo)
-      return Op->hasNoSignedZeros();
-    return false;
-  }
-};
-
-struct SimplifyQuery {
-  const DataLayout &DL;
-  const TargetLibraryInfo *TLI = nullptr;
-  const DominatorTree *DT = nullptr;
-  AssumptionCache *AC = nullptr;
-  const Instruction *CxtI = nullptr;
-
-  // Wrapper to query additional information for instructions like metadata or
-  // keywords like nsw, which provides conservative results if those cannot
-  // be safely used.
-  const InstrInfoQuery IIQ;
-
-  /// Controls whether simplifications are allowed to constrain the range of
-  /// possible values for uses of undef. If it is false, simplifications are not
-  /// allowed to assume a particular value for a use of undef for example.
-  bool CanUseUndef = true;
-
-  SimplifyQuery(const DataLayout &DL, const Instruction *CXTI = nullptr)
-      : DL(DL), CxtI(CXTI) {}
-
-  SimplifyQuery(const DataLayout &DL, const TargetLibraryInfo *TLI,
-                const DominatorTree *DT = nullptr,
-                AssumptionCache *AC = nullptr,
-                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
-                bool CanUseUndef = true)
-      : DL(DL), TLI(TLI), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
-        CanUseUndef(CanUseUndef) {}
-
-  SimplifyQuery(const DataLayout &DL, const DominatorTree *DT,
-                AssumptionCache *AC = nullptr,
-                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
-                bool CanUseUndef = true)
-      : DL(DL), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
-        CanUseUndef(CanUseUndef) {}
-
-  SimplifyQuery getWithInstruction(const Instruction *I) const {
-    SimplifyQuery Copy(*this);
-    Copy.CxtI = I;
-    return Copy;
-  }
-  SimplifyQuery getWithoutUndef() const {
-    SimplifyQuery Copy(*this);
-    Copy.CanUseUndef = false;
-    return Copy;
-  }
-
-  /// If CanUseUndef is true, returns whether \p V is undef.
-  /// Otherwise always return false.
-  bool isUndefValue(Value *V) const {
-    if (!CanUseUndef)
-      return false;
-
-    using namespace PatternMatch;
-    return match(V, m_Undef());
-  }
-};
-
 // NOTE: the explicit multiple argument versions of these functions are
 // deprecated.
 // Please use the SimplifyQuery versions in new code.
diff --git a/llvm/include/llvm/Analysis/SimplifyQuery.h b/llvm/include/llvm/Analysis/SimplifyQuery.h
new file mode 100644
index 000000000000000..f9cc3029221d679
--- /dev/null
+++ b/llvm/include/llvm/Analysis/SimplifyQuery.h
@@ -0,0 +1,118 @@
+//===-- SimplifyQuery.h - Context for simplifications -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_SIMPLIFYQUERY_H
+#define LLVM_ANALYSIS_SIMPLIFYQUERY_H
+
+#include "llvm/IR/PatternMatch.h"
+
+namespace llvm {
+
+class AssumptionCache;
+class DominatorTree;
+class TargetLibraryInfo;
+
+/// InstrInfoQuery provides an interface to query additional information for
+/// instructions, like metadata or keywords like nsw. It returns conservative
+/// results if the user specified that this information is not safe to use.
+struct InstrInfoQuery {
+  InstrInfoQuery(bool UMD) : UseInstrInfo(UMD) {}
+  InstrInfoQuery() = default;
+  bool UseInstrInfo = true;
+
+  MDNode *getMetadata(const Instruction *I, unsigned KindID) const {
+    if (UseInstrInfo)
+      return I->getMetadata(KindID);
+    return nullptr;
+  }
+
+  template <class InstT> bool hasNoUnsignedWrap(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->hasNoUnsignedWrap();
+    return false;
+  }
+
+  template <class InstT> bool hasNoSignedWrap(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->hasNoSignedWrap();
+    return false;
+  }
+
+  bool isExact(const BinaryOperator *Op) const {
+    if (UseInstrInfo && isa<PossiblyExactOperator>(Op))
+      return cast<PossiblyExactOperator>(Op)->isExact();
+    return false;
+  }
+
+  template <class InstT> bool hasNoSignedZeros(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->hasNoSignedZeros();
+    return false;
+  }
+};
+
+struct SimplifyQuery {
+  const DataLayout &DL;
+  const TargetLibraryInfo *TLI = nullptr;
+  const DominatorTree *DT = nullptr;
+  AssumptionCache *AC = nullptr;
+  const Instruction *CxtI = nullptr;
+
+  // Wrapper to query additional information for instructions like metadata or
+  // keywords like nsw, which provides conservative results if those cannot
+  // be safely used.
+  const InstrInfoQuery IIQ;
+
+  /// Controls whether simplifications are allowed to constrain the range of
+  /// possible values for uses of undef. If it is false, simplifications are not
+  /// allowed to assume a particular value for a use of undef for example.
+  bool CanUseUndef = true;
+
+  SimplifyQuery(const DataLayout &DL, const Instruction *CXTI = nullptr)
+      : DL(DL), CxtI(CXTI) {}
+
+  SimplifyQuery(const DataLayout &DL, const TargetLibraryInfo *TLI,
+                const DominatorTree *DT = nullptr,
+                AssumptionCache *AC = nullptr,
+                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
+                bool CanUseUndef = true)
+      : DL(DL), TLI(TLI), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
+        CanUseUndef(CanUseUndef) {}
+
+  SimplifyQuery(const DataLayout &DL, const DominatorTree *DT,
+                AssumptionCache *AC = nullptr,
+                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
+                bool CanUseUndef = true)
+      : DL(DL), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
+        CanUseUndef(CanUseUndef) {}
+
+  SimplifyQuery getWithInstruction(const Instruction *I) const {
+    SimplifyQuery Copy(*this);
+    Copy.CxtI = I;
+    return Copy;
+  }
+  SimplifyQuery getWithoutUndef() const {
+    SimplifyQuery Copy(*this);
+    Copy.CanUseUndef = false;
+    return Copy;
+  }
+
+  /// If CanUseUndef is true, returns whether \p V is undef.
+  /// Otherwise always return false.
+  bool isUndefValue(Value *V) const {
+    if (!CanUseUndef)
+      return false;
+
+    using namespace PatternMatch;
+    return match(V, m_Undef());
+  }
+};
+
+} // end namespace llvm
+
+#endif

>From 1b3cc4e715ad144fc93c4098fee21b18674926f0 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 5 Oct 2023 10:04:41 +0200
Subject: [PATCH 2/2] [ValueTracking] Use SimplifyQuery for the overflow APIs
 (NFC)

Accept a SimplifyQuery instead of an unpacked list of arguments.
---
 llvm/include/llvm/Analysis/ValueTracking.h    |  40 ++----
 .../Transforms/InstCombine/InstCombiner.h     |  18 ++-
 llvm/lib/Analysis/ValueTracking.cpp           | 132 +++++++-----------
 llvm/lib/Transforms/Scalar/LICM.cpp           |  14 +-
 llvm/lib/Transforms/Scalar/LoopFlatten.cpp    |   5 +-
 .../lib/Transforms/Scalar/NaryReassociate.cpp |   4 +-
 6 files changed, 86 insertions(+), 127 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index a2f5f23d94ee812..5778df32ef1b8cd 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -16,6 +16,7 @@
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/SimplifyQuery.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/FMF.h"
@@ -39,7 +40,6 @@ struct KnownBits;
 class Loop;
 class LoopInfo;
 class MDNode;
-struct SimplifyQuery;
 class StringRef;
 class TargetLibraryInfo;
 class Value;
@@ -851,44 +851,20 @@ enum class OverflowResult {
 };
 
 OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS,
-                                             const DataLayout &DL,
-                                             AssumptionCache *AC,
-                                             const Instruction *CxtI,
-                                             const DominatorTree *DT,
-                                             bool UseInstrInfo = true);
+                                             const SimplifyQuery &SQ);
 OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
-                                           const DataLayout &DL,
-                                           AssumptionCache *AC,
-                                           const Instruction *CxtI,
-                                           const DominatorTree *DT,
-                                           bool UseInstrInfo = true);
+                                           const SimplifyQuery &SQ);
 OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS,
-                                             const DataLayout &DL,
-                                             AssumptionCache *AC,
-                                             const Instruction *CxtI,
-                                             const DominatorTree *DT,
-                                             bool UseInstrInfo = true);
+                                             const SimplifyQuery &SQ);
 OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
-                                           const DataLayout &DL,
-                                           AssumptionCache *AC = nullptr,
-                                           const Instruction *CxtI = nullptr,
-                                           const DominatorTree *DT = nullptr);
+                                           const SimplifyQuery &SQ);
 /// This version also leverages the sign bit of Add if known.
 OverflowResult computeOverflowForSignedAdd(const AddOperator *Add,
-                                           const DataLayout &DL,
-                                           AssumptionCache *AC = nullptr,
-                                           const Instruction *CxtI = nullptr,
-                                           const DominatorTree *DT = nullptr);
+                                           const SimplifyQuery &SQ);
 OverflowResult computeOverflowForUnsignedSub(const Value *LHS, const Value *RHS,
-                                             const DataLayout &DL,
-                                             AssumptionCache *AC,
-                                             const Instruction *CxtI,
-                                             const DominatorTree *DT);
+                                             const SimplifyQuery &SQ);
 OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS,
-                                           const DataLayout &DL,
-                                           AssumptionCache *AC,
-                                           const Instruction *CxtI,
-                                           const DominatorTree *DT);
+                                           const SimplifyQuery &SQ);
 
 /// Returns true if the arithmetic part of the \p WO 's result is
 /// used only along the paths control dependent on the computation
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index c1cea5649f769ab..dcfcc8f41dd58d0 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -500,34 +500,40 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
                                                const Value *RHS,
                                                const Instruction *CxtI) const {
-    return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForUnsignedMul(LHS, RHS,
+                                               SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
                                              const Instruction *CxtI) const {
-    return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForSignedMul(LHS, RHS,
+                                             SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
                                                const Value *RHS,
                                                const Instruction *CxtI) const {
-    return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForUnsignedAdd(LHS, RHS,
+                                               SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
                                              const Instruction *CxtI) const {
-    return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForSignedAdd(LHS, RHS,
+                                             SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForUnsignedSub(const Value *LHS,
                                                const Value *RHS,
                                                const Instruction *CxtI) const {
-    return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForUnsignedSub(LHS, RHS,
+                                               SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS,
                                              const Instruction *CxtI) const {
-    return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForSignedSub(LHS, RHS,
+                                             SQ.getWithInstruction(CxtI));
   }
 
   virtual bool SimplifyDemandedBits(Instruction *I, unsigned OpNo,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 3af5a6d9a167de4..ce7f9a5ade8ff12 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -6245,37 +6245,30 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
 }
 
 /// Combine constant ranges from computeConstantRange() and computeKnownBits().
-static ConstantRange computeConstantRangeIncludingKnownBits(
-    const Value *V, bool ForSigned, const DataLayout &DL, AssumptionCache *AC,
-    const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo = true) {
-  KnownBits Known =
-      computeKnownBits(V, DL, /*Depth=*/0, AC, CxtI, DT, UseInstrInfo);
+static ConstantRange
+computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
+                                       const SimplifyQuery &SQ) {
+  KnownBits Known = ::computeKnownBits(V, /*Depth=*/0, SQ);
   ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
-  ConstantRange CR2 = computeConstantRange(V, ForSigned, UseInstrInfo);
+  ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
   ConstantRange::PreferredRangeType RangeType =
       ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
   return CR1.intersectWith(CR2, RangeType);
 }
 
-OverflowResult llvm::computeOverflowForUnsignedMul(
-    const Value *LHS, const Value *RHS, const DataLayout &DL,
-    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo) {
-  KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                        UseInstrInfo);
-  KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                        UseInstrInfo);
+OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
+                                                   const Value *RHS,
+                                                   const SimplifyQuery &SQ) {
+  KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
+  KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
   ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
   ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
   return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
 }
 
-OverflowResult
-llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
-                                  const DataLayout &DL, AssumptionCache *AC,
-                                  const Instruction *CxtI,
-                                  const DominatorTree *DT, bool UseInstrInfo) {
+OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
+                                                 const Value *RHS,
+                                                 const SimplifyQuery &SQ) {
   // Multiplying n * m significant bits yields a result of n + m significant
   // bits. If the total number of significant bits does not exceed the
   // result bit width (minus 1), there is no overflow.
@@ -6286,8 +6279,8 @@ llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
 
   // Note that underestimating the number of sign bits gives a more
   // conservative answer.
-  unsigned SignBits = ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) +
-                      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT);
+  unsigned SignBits =
+      ::ComputeNumSignBits(LHS, 0, SQ) + ::ComputeNumSignBits(RHS, 0, SQ);
 
   // First handle the easy case: if we have enough sign bits there's
   // definitely no overflow.
@@ -6304,34 +6297,28 @@ llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
     // product is exactly the minimum negative number.
     // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
     // For simplicity we just check if at least one side is not negative.
-    KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                          UseInstrInfo);
-    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                          UseInstrInfo);
+    KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
+    KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
     if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
       return OverflowResult::NeverOverflows;
   }
   return OverflowResult::MayOverflow;
 }
 
-OverflowResult llvm::computeOverflowForUnsignedAdd(
-    const Value *LHS, const Value *RHS, const DataLayout &DL,
-    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo) {
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/false, DL, AC, CxtI, DT, UseInstrInfo);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/false, DL, AC, CxtI, DT, UseInstrInfo);
+OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
+                                                   const Value *RHS,
+                                                   const SimplifyQuery &SQ) {
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
   return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
 }
 
 static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
                                                   const Value *RHS,
                                                   const AddOperator *Add,
-                                                  const DataLayout &DL,
-                                                  AssumptionCache *AC,
-                                                  const Instruction *CxtI,
-                                                  const DominatorTree *DT) {
+                                                  const SimplifyQuery &SQ) {
   if (Add && Add->hasNoSignedWrap()) {
     return OverflowResult::NeverOverflows;
   }
@@ -6350,14 +6337,14 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
   //
   // Since the carry into the most significant position is always equal to
   // the carry out of the addition, there is no signed overflow.
-  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
-      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+  if (::ComputeNumSignBits(LHS, 0, SQ) > 1 &&
+      ::ComputeNumSignBits(RHS, 0, SQ) > 1)
     return OverflowResult::NeverOverflows;
 
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
   OverflowResult OR =
       mapOverflowResult(LHSRange.signedAddMayOverflow(RHSRange));
   if (OR != OverflowResult::MayOverflow)
@@ -6378,8 +6365,7 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
       (LHSRange.isAllNegative() || RHSRange.isAllNegative());
   if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
     KnownBits AddKnown(LHSRange.getBitWidth());
-    computeKnownBitsFromAssume(Add, AddKnown, /*Depth=*/0,
-                               SimplifyQuery(DL, DT, AC, CxtI, DT));
+    computeKnownBitsFromAssume(Add, AddKnown, /*Depth=*/0, SQ);
     if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
         (AddKnown.isNegative() && LHSOrRHSKnownNegative))
       return OverflowResult::NeverOverflows;
@@ -6390,10 +6376,7 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
 
 OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                    const Value *RHS,
-                                                   const DataLayout &DL,
-                                                   AssumptionCache *AC,
-                                                   const Instruction *CxtI,
-                                                   const DominatorTree *DT) {
+                                                   const SimplifyQuery &SQ) {
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -6407,32 +6390,29 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
   //       See simplifyICmpWithBinOpOnLHS() for candidates.
   if (match(RHS, m_URem(m_Specific(LHS), m_Value())) ||
       match(RHS, m_NUWSub(m_Specific(LHS), m_Value())))
-    if (isGuaranteedNotToBeUndefOrPoison(LHS, AC, CxtI, DT))
+    if (isGuaranteedNotToBeUndefOrPoison(LHS, SQ.AC, SQ.CxtI, SQ.DT))
       return OverflowResult::NeverOverflows;
 
   // Checking for conditions implied by dominating conditions may be expensive.
   // Limit it to usub_with_overflow calls for now.
-  if (match(CxtI,
+  if (match(SQ.CxtI,
             m_Intrinsic<Intrinsic::usub_with_overflow>(m_Value(), m_Value())))
-    if (auto C =
-            isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, CxtI, DL)) {
+    if (auto C = isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, SQ.CxtI,
+                                         SQ.DL)) {
       if (*C)
         return OverflowResult::NeverOverflows;
       return OverflowResult::AlwaysOverflowsLow;
     }
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/false, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/false, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
   return mapOverflowResult(LHSRange.unsignedSubMayOverflow(RHSRange));
 }
 
 OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
                                                  const Value *RHS,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
+                                                 const SimplifyQuery &SQ) {
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -6443,19 +6423,19 @@ OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
   // then determining no-overflow may allow other transforms.
   if (match(RHS, m_SRem(m_Specific(LHS), m_Value())) ||
       match(RHS, m_NSWSub(m_Specific(LHS), m_Value())))
-    if (isGuaranteedNotToBeUndefOrPoison(LHS, AC, CxtI, DT))
+    if (isGuaranteedNotToBeUndefOrPoison(LHS, SQ.AC, SQ.CxtI, SQ.DT))
       return OverflowResult::NeverOverflows;
 
   // If LHS and RHS each have at least two sign bits, the subtraction
   // cannot overflow.
-  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
-      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+  if (::ComputeNumSignBits(LHS, 0, SQ) > 1 &&
+      ::ComputeNumSignBits(RHS, 0, SQ) > 1)
     return OverflowResult::NeverOverflows;
 
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
   return mapOverflowResult(LHSRange.signedSubMayOverflow(RHSRange));
 }
 
@@ -6949,21 +6929,15 @@ bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
 }
 
 OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
+                                                 const SimplifyQuery &SQ) {
   return ::computeOverflowForSignedAdd(Add->getOperand(0), Add->getOperand(1),
-                                       Add, DL, AC, CxtI, DT);
+                                       Add, SQ);
 }
 
 OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
                                                  const Value *RHS,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
-  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, DL, AC, CxtI, DT);
+                                                 const SimplifyQuery &SQ) {
+  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
 }
 
 bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 4cb70cbdf093b36..224db19caac82ab 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -2540,8 +2540,9 @@ static bool hoistAdd(ICmpInst::Predicate Pred, Value *VariantLHS,
   // we want to avoid this.
   auto &DL = L.getHeader()->getModule()->getDataLayout();
   bool ProvedNoOverflowAfterReassociate =
-      computeOverflowForSignedSub(InvariantRHS, InvariantOp, DL, AC, &ICmp,
-                                  DT) == llvm::OverflowResult::NeverOverflows;
+      computeOverflowForSignedSub(InvariantRHS, InvariantOp,
+                                  SimplifyQuery(DL, DT, AC, &ICmp)) ==
+      llvm::OverflowResult::NeverOverflows;
   if (!ProvedNoOverflowAfterReassociate)
     return false;
   auto *Preheader = L.getLoopPreheader();
@@ -2591,15 +2592,16 @@ static bool hoistSub(ICmpInst::Predicate Pred, Value *VariantLHS,
   // we want to avoid this. Likewise, for "C1 - LV < C2" we need to prove that
   // "C1 - C2" does not overflow.
   auto &DL = L.getHeader()->getModule()->getDataLayout();
+  SimplifyQuery SQ(DL, DT, AC, &ICmp);
   if (VariantSubtracted) {
     // C1 - LV < C2 --> LV > C1 - C2
-    if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, DL, AC, &ICmp,
-                                    DT) != llvm::OverflowResult::NeverOverflows)
+    if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, SQ) !=
+        llvm::OverflowResult::NeverOverflows)
       return false;
   } else {
     // LV - C1 < C2 --> LV < C1 + C2
-    if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, DL, AC, &ICmp,
-                                    DT) != llvm::OverflowResult::NeverOverflows)
+    if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, SQ) !=
+        llvm::OverflowResult::NeverOverflows)
       return false;
   }
   auto *Preheader = L.getLoopPreheader();
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index edc8a4956dd1c71..b1add3c42976fd6 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -641,8 +641,9 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
   // Check if the multiply could not overflow due to known ranges of the
   // input values.
   OverflowResult OR = computeOverflowForUnsignedMul(
-      FI.InnerTripCount, FI.OuterTripCount, DL, AC,
-      FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
+      FI.InnerTripCount, FI.OuterTripCount,
+      SimplifyQuery(DL, DT, AC,
+                    FI.OuterLoop->getLoopPreheader()->getTerminator()));
   if (OR != OverflowResult::MayOverflow)
     return OR;
 
diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
index 9c3e9a2fd018aca..021b03e42403cd7 100644
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -372,9 +372,9 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
     // If the I-th index needs sext and the underlying add is not equipped with
     // nsw, we cannot split the add because
     //   sext(LHS + RHS) != sext(LHS) + sext(RHS).
+    SimplifyQuery SQ(*DL, DT, AC, GEP);
     if (requiresSignExtension(IndexToSplit, GEP) &&
-        computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) !=
-            OverflowResult::NeverOverflows)
+        computeOverflowForSignedAdd(AO, SQ) != OverflowResult::NeverOverflows)
       return nullptr;
 
     Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);



More information about the llvm-commits mailing list