[llvm] [ValueTracking] Use SimplifyQuery in some public APIs (NFC) (PR #68290)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 5 01:36:39 PDT 2023


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/68290

This patch moves SimplifyQuery into a separate header, so that both InstSimplify and ValueTracking can use it in their public API. ValueTracking already (mostly) uses SimplifyQuery for its internal recursive calls, but the public API accepts unpacked arguments.

I'd like to move things towards accepting SimplifyQuery in the public API as well. This patch in particular migrates the computeOverflowFor*() APIs (computeOverflowForUnsignedMul(), computeOverflowForSignedAdd(), and so on).

From 81adf52b953a7166ae44342a84bdf7a1e2cbe3ba Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 5 Oct 2023 10:00:48 +0200
Subject: [PATCH 1/2] [Analysis] Move SimplifyQuery into separate header (NFC)

To allow reusing it between InstructionSimplify and ValueTracking.
---
 .../llvm/Analysis/InstructionSimplify.h       |  98 +--------------
 llvm/include/llvm/Analysis/SimplifyQuery.h    | 118 ++++++++++++++++++
 2 files changed, 119 insertions(+), 97 deletions(-)
 create mode 100644 llvm/include/llvm/Analysis/SimplifyQuery.h

diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 92005da1f4c61e7..c626a6522d01779 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -31,7 +31,7 @@
 #ifndef LLVM_ANALYSIS_INSTRUCTIONSIMPLIFY_H
 #define LLVM_ANALYSIS_INSTRUCTIONSIMPLIFY_H
 
-#include "llvm/IR/PatternMatch.h"
+#include "llvm/Analysis/SimplifyQuery.h"
 
 namespace llvm {
 
@@ -52,102 +52,6 @@ class TargetLibraryInfo;
 class Type;
 class Value;
 
-/// InstrInfoQuery provides an interface to query additional information for
-/// instructions like metadata or keywords like nsw, which provides conservative
-/// results if the users specified it is safe to use.
-struct InstrInfoQuery {
-  InstrInfoQuery(bool UMD) : UseInstrInfo(UMD) {}
-  InstrInfoQuery() = default;
-  bool UseInstrInfo = true;
-
-  MDNode *getMetadata(const Instruction *I, unsigned KindID) const {
-    if (UseInstrInfo)
-      return I->getMetadata(KindID);
-    return nullptr;
-  }
-
-  template <class InstT> bool hasNoUnsignedWrap(const InstT *Op) const {
-    if (UseInstrInfo)
-      return Op->hasNoUnsignedWrap();
-    return false;
-  }
-
-  template <class InstT> bool hasNoSignedWrap(const InstT *Op) const {
-    if (UseInstrInfo)
-      return Op->hasNoSignedWrap();
-    return false;
-  }
-
-  bool isExact(const BinaryOperator *Op) const {
-    if (UseInstrInfo && isa<PossiblyExactOperator>(Op))
-      return cast<PossiblyExactOperator>(Op)->isExact();
-    return false;
-  }
-
-  template <class InstT> bool hasNoSignedZeros(const InstT *Op) const {
-    if (UseInstrInfo)
-      return Op->hasNoSignedZeros();
-    return false;
-  }
-};
-
-struct SimplifyQuery {
-  const DataLayout &DL;
-  const TargetLibraryInfo *TLI = nullptr;
-  const DominatorTree *DT = nullptr;
-  AssumptionCache *AC = nullptr;
-  const Instruction *CxtI = nullptr;
-
-  // Wrapper to query additional information for instructions like metadata or
-  // keywords like nsw, which provides conservative results if those cannot
-  // be safely used.
-  const InstrInfoQuery IIQ;
-
-  /// Controls whether simplifications are allowed to constrain the range of
-  /// possible values for uses of undef. If it is false, simplifications are not
-  /// allowed to assume a particular value for a use of undef for example.
-  bool CanUseUndef = true;
-
-  SimplifyQuery(const DataLayout &DL, const Instruction *CXTI = nullptr)
-      : DL(DL), CxtI(CXTI) {}
-
-  SimplifyQuery(const DataLayout &DL, const TargetLibraryInfo *TLI,
-                const DominatorTree *DT = nullptr,
-                AssumptionCache *AC = nullptr,
-                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
-                bool CanUseUndef = true)
-      : DL(DL), TLI(TLI), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
-        CanUseUndef(CanUseUndef) {}
-
-  SimplifyQuery(const DataLayout &DL, const DominatorTree *DT,
-                AssumptionCache *AC = nullptr,
-                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
-                bool CanUseUndef = true)
-      : DL(DL), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
-        CanUseUndef(CanUseUndef) {}
-
-  SimplifyQuery getWithInstruction(const Instruction *I) const {
-    SimplifyQuery Copy(*this);
-    Copy.CxtI = I;
-    return Copy;
-  }
-  SimplifyQuery getWithoutUndef() const {
-    SimplifyQuery Copy(*this);
-    Copy.CanUseUndef = false;
-    return Copy;
-  }
-
-  /// If CanUseUndef is true, returns whether \p V is undef.
-  /// Otherwise always return false.
-  bool isUndefValue(Value *V) const {
-    if (!CanUseUndef)
-      return false;
-
-    using namespace PatternMatch;
-    return match(V, m_Undef());
-  }
-};
-
 // NOTE: the explicit multiple argument versions of these functions are
 // deprecated.
 // Please use the SimplifyQuery versions in new code.
diff --git a/llvm/include/llvm/Analysis/SimplifyQuery.h b/llvm/include/llvm/Analysis/SimplifyQuery.h
new file mode 100644
index 000000000000000..f9cc3029221d679
--- /dev/null
+++ b/llvm/include/llvm/Analysis/SimplifyQuery.h
@@ -0,0 +1,118 @@
+//===-- SimplifyQuery.h - Context for simplifications -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_SIMPLIFYQUERY_H
+#define LLVM_ANALYSIS_SIMPLIFYQUERY_H
+
+#include "llvm/IR/PatternMatch.h"
+
+namespace llvm {
+
+class AssumptionCache;
+class DominatorTree;
+class TargetLibraryInfo;
+
+/// InstrInfoQuery provides an interface to query additional information for
+/// instructions like metadata or keywords like nsw, which provides conservative
+/// results if the users specified it is safe to use.
+struct InstrInfoQuery {
+  InstrInfoQuery(bool UMD) : UseInstrInfo(UMD) {}
+  InstrInfoQuery() = default;
+  bool UseInstrInfo = true;
+
+  MDNode *getMetadata(const Instruction *I, unsigned KindID) const {
+    if (UseInstrInfo)
+      return I->getMetadata(KindID);
+    return nullptr;
+  }
+
+  template <class InstT> bool hasNoUnsignedWrap(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->hasNoUnsignedWrap();
+    return false;
+  }
+
+  template <class InstT> bool hasNoSignedWrap(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->hasNoSignedWrap();
+    return false;
+  }
+
+  bool isExact(const BinaryOperator *Op) const {
+    if (UseInstrInfo && isa<PossiblyExactOperator>(Op))
+      return cast<PossiblyExactOperator>(Op)->isExact();
+    return false;
+  }
+
+  template <class InstT> bool hasNoSignedZeros(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->hasNoSignedZeros();
+    return false;
+  }
+};
+
+struct SimplifyQuery {
+  const DataLayout &DL;
+  const TargetLibraryInfo *TLI = nullptr;
+  const DominatorTree *DT = nullptr;
+  AssumptionCache *AC = nullptr;
+  const Instruction *CxtI = nullptr;
+
+  // Wrapper to query additional information for instructions like metadata or
+  // keywords like nsw, which provides conservative results if those cannot
+  // be safely used.
+  const InstrInfoQuery IIQ;
+
+  /// Controls whether simplifications are allowed to constrain the range of
+  /// possible values for uses of undef. If it is false, simplifications are not
+  /// allowed to assume a particular value for a use of undef for example.
+  bool CanUseUndef = true;
+
+  SimplifyQuery(const DataLayout &DL, const Instruction *CXTI = nullptr)
+      : DL(DL), CxtI(CXTI) {}
+
+  SimplifyQuery(const DataLayout &DL, const TargetLibraryInfo *TLI,
+                const DominatorTree *DT = nullptr,
+                AssumptionCache *AC = nullptr,
+                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
+                bool CanUseUndef = true)
+      : DL(DL), TLI(TLI), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
+        CanUseUndef(CanUseUndef) {}
+
+  SimplifyQuery(const DataLayout &DL, const DominatorTree *DT,
+                AssumptionCache *AC = nullptr,
+                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
+                bool CanUseUndef = true)
+      : DL(DL), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
+        CanUseUndef(CanUseUndef) {}
+
+  SimplifyQuery getWithInstruction(const Instruction *I) const {
+    SimplifyQuery Copy(*this);
+    Copy.CxtI = I;
+    return Copy;
+  }
+  SimplifyQuery getWithoutUndef() const {
+    SimplifyQuery Copy(*this);
+    Copy.CanUseUndef = false;
+    return Copy;
+  }
+
+  /// If CanUseUndef is true, returns whether \p V is undef.
+  /// Otherwise always return false.
+  bool isUndefValue(Value *V) const {
+    if (!CanUseUndef)
+      return false;
+
+    using namespace PatternMatch;
+    return match(V, m_Undef());
+  }
+};
+
+} // end namespace llvm
+
+#endif

From bd1de448e9f9c40b45afd4d4beff750e59fcec60 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 5 Oct 2023 10:04:41 +0200
Subject: [PATCH 2/2] [ValueTracking] Use SimplifyQuery for the overflow APIs
 (NFC)

Accept a SimplifyQuery instead of an unpacked list of arguments.
---
 llvm/include/llvm/Analysis/ValueTracking.h    |  40 ++----
 .../Transforms/InstCombine/InstCombiner.h     |  18 ++-
 llvm/lib/Analysis/ValueTracking.cpp           | 132 +++++++-----------
 llvm/lib/Transforms/Scalar/LICM.cpp           |  14 +-
 llvm/lib/Transforms/Scalar/LoopFlatten.cpp    |   5 +-
 .../lib/Transforms/Scalar/NaryReassociate.cpp |   4 +-
 6 files changed, 86 insertions(+), 127 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index a2f5f23d94ee812..5778df32ef1b8cd 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -16,6 +16,7 @@
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/SimplifyQuery.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/FMF.h"
@@ -39,7 +40,6 @@ struct KnownBits;
 class Loop;
 class LoopInfo;
 class MDNode;
-struct SimplifyQuery;
 class StringRef;
 class TargetLibraryInfo;
 class Value;
@@ -851,44 +851,20 @@ enum class OverflowResult {
 };
 
 OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS,
-                                             const DataLayout &DL,
-                                             AssumptionCache *AC,
-                                             const Instruction *CxtI,
-                                             const DominatorTree *DT,
-                                             bool UseInstrInfo = true);
+                                             const SimplifyQuery &SQ);
 OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
-                                           const DataLayout &DL,
-                                           AssumptionCache *AC,
-                                           const Instruction *CxtI,
-                                           const DominatorTree *DT,
-                                           bool UseInstrInfo = true);
+                                           const SimplifyQuery &SQ);
 OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS,
-                                             const DataLayout &DL,
-                                             AssumptionCache *AC,
-                                             const Instruction *CxtI,
-                                             const DominatorTree *DT,
-                                             bool UseInstrInfo = true);
+                                             const SimplifyQuery &SQ);
 OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
-                                           const DataLayout &DL,
-                                           AssumptionCache *AC = nullptr,
-                                           const Instruction *CxtI = nullptr,
-                                           const DominatorTree *DT = nullptr);
+                                           const SimplifyQuery &SQ);
 /// This version also leverages the sign bit of Add if known.
 OverflowResult computeOverflowForSignedAdd(const AddOperator *Add,
-                                           const DataLayout &DL,
-                                           AssumptionCache *AC = nullptr,
-                                           const Instruction *CxtI = nullptr,
-                                           const DominatorTree *DT = nullptr);
+                                           const SimplifyQuery &SQ);
 OverflowResult computeOverflowForUnsignedSub(const Value *LHS, const Value *RHS,
-                                             const DataLayout &DL,
-                                             AssumptionCache *AC,
-                                             const Instruction *CxtI,
-                                             const DominatorTree *DT);
+                                             const SimplifyQuery &SQ);
 OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS,
-                                           const DataLayout &DL,
-                                           AssumptionCache *AC,
-                                           const Instruction *CxtI,
-                                           const DominatorTree *DT);
+                                           const SimplifyQuery &SQ);
 
 /// Returns true if the arithmetic part of the \p WO 's result is
 /// used only along the paths control dependent on the computation
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index c1cea5649f769ab..dcfcc8f41dd58d0 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -500,34 +500,40 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
                                                const Value *RHS,
                                                const Instruction *CxtI) const {
-    return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForUnsignedMul(LHS, RHS,
+                                               SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
                                              const Instruction *CxtI) const {
-    return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForSignedMul(LHS, RHS,
+                                             SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
                                                const Value *RHS,
                                                const Instruction *CxtI) const {
-    return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForUnsignedAdd(LHS, RHS,
+                                               SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
                                              const Instruction *CxtI) const {
-    return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForSignedAdd(LHS, RHS,
+                                             SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForUnsignedSub(const Value *LHS,
                                                const Value *RHS,
                                                const Instruction *CxtI) const {
-    return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForUnsignedSub(LHS, RHS,
+                                               SQ.getWithInstruction(CxtI));
   }
 
   OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS,
                                              const Instruction *CxtI) const {
-    return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
+    return llvm::computeOverflowForSignedSub(LHS, RHS,
+                                             SQ.getWithInstruction(CxtI));
   }
 
   virtual bool SimplifyDemandedBits(Instruction *I, unsigned OpNo,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index b76becf24d10fc9..ad55698c7c19025 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -6245,37 +6245,30 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
 }
 
 /// Combine constant ranges from computeConstantRange() and computeKnownBits().
-static ConstantRange computeConstantRangeIncludingKnownBits(
-    const Value *V, bool ForSigned, const DataLayout &DL, AssumptionCache *AC,
-    const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo = true) {
-  KnownBits Known =
-      computeKnownBits(V, DL, /*Depth=*/0, AC, CxtI, DT, UseInstrInfo);
+static ConstantRange
+computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
+                                       const SimplifyQuery &SQ) {
+  KnownBits Known = ::computeKnownBits(V, /*Depth=*/0, SQ);
   ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
-  ConstantRange CR2 = computeConstantRange(V, ForSigned, UseInstrInfo);
+  ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
   ConstantRange::PreferredRangeType RangeType =
       ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
   return CR1.intersectWith(CR2, RangeType);
 }
 
-OverflowResult llvm::computeOverflowForUnsignedMul(
-    const Value *LHS, const Value *RHS, const DataLayout &DL,
-    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo) {
-  KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                        UseInstrInfo);
-  KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                        UseInstrInfo);
+OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
+                                                   const Value *RHS,
+                                                   const SimplifyQuery &SQ) {
+  KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
+  KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
   ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
   ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
   return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
 }
 
-OverflowResult
-llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
-                                  const DataLayout &DL, AssumptionCache *AC,
-                                  const Instruction *CxtI,
-                                  const DominatorTree *DT, bool UseInstrInfo) {
+OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
+                                                 const Value *RHS,
+                                                 const SimplifyQuery &SQ) {
   // Multiplying n * m significant bits yields a result of n + m significant
   // bits. If the total number of significant bits does not exceed the
   // result bit width (minus 1), there is no overflow.
@@ -6286,8 +6279,8 @@ llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
 
   // Note that underestimating the number of sign bits gives a more
   // conservative answer.
-  unsigned SignBits = ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) +
-                      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT);
+  unsigned SignBits =
+      ::ComputeNumSignBits(LHS, 0, SQ) + ::ComputeNumSignBits(RHS, 0, SQ);
 
   // First handle the easy case: if we have enough sign bits there's
   // definitely no overflow.
@@ -6304,34 +6297,28 @@ llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
     // product is exactly the minimum negative number.
     // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
     // For simplicity we just check if at least one side is not negative.
-    KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                          UseInstrInfo);
-    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                          UseInstrInfo);
+    KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
+    KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
     if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
       return OverflowResult::NeverOverflows;
   }
   return OverflowResult::MayOverflow;
 }
 
-OverflowResult llvm::computeOverflowForUnsignedAdd(
-    const Value *LHS, const Value *RHS, const DataLayout &DL,
-    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo) {
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/false, DL, AC, CxtI, DT, UseInstrInfo);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/false, DL, AC, CxtI, DT, UseInstrInfo);
+OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
+                                                   const Value *RHS,
+                                                   const SimplifyQuery &SQ) {
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
   return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
 }
 
 static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
                                                   const Value *RHS,
                                                   const AddOperator *Add,
-                                                  const DataLayout &DL,
-                                                  AssumptionCache *AC,
-                                                  const Instruction *CxtI,
-                                                  const DominatorTree *DT) {
+                                                  const SimplifyQuery &SQ) {
   if (Add && Add->hasNoSignedWrap()) {
     return OverflowResult::NeverOverflows;
   }
@@ -6350,14 +6337,14 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
   //
   // Since the carry into the most significant position is always equal to
   // the carry out of the addition, there is no signed overflow.
-  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
-      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+  if (::ComputeNumSignBits(LHS, 0, SQ) > 1 &&
+      ::ComputeNumSignBits(RHS, 0, SQ) > 1)
     return OverflowResult::NeverOverflows;
 
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
   OverflowResult OR =
       mapOverflowResult(LHSRange.signedAddMayOverflow(RHSRange));
   if (OR != OverflowResult::MayOverflow)
@@ -6378,8 +6365,7 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
       (LHSRange.isAllNegative() || RHSRange.isAllNegative());
   if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
     KnownBits AddKnown(LHSRange.getBitWidth());
-    computeKnownBitsFromAssume(Add, AddKnown, /*Depth=*/0,
-                               SimplifyQuery(DL, DT, AC, CxtI, DT));
+    computeKnownBitsFromAssume(Add, AddKnown, /*Depth=*/0, SQ);
     if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
         (AddKnown.isNegative() && LHSOrRHSKnownNegative))
       return OverflowResult::NeverOverflows;
@@ -6390,10 +6376,7 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
 
 OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                    const Value *RHS,
-                                                   const DataLayout &DL,
-                                                   AssumptionCache *AC,
-                                                   const Instruction *CxtI,
-                                                   const DominatorTree *DT) {
+                                                   const SimplifyQuery &SQ) {
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -6407,32 +6390,29 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
   //       See simplifyICmpWithBinOpOnLHS() for candidates.
   if (match(RHS, m_URem(m_Specific(LHS), m_Value())) ||
       match(RHS, m_NUWSub(m_Specific(LHS), m_Value())))
-    if (isGuaranteedNotToBeUndefOrPoison(LHS, AC, CxtI, DT))
+    if (isGuaranteedNotToBeUndefOrPoison(LHS, SQ.AC, SQ.CxtI, SQ.DT))
       return OverflowResult::NeverOverflows;
 
   // Checking for conditions implied by dominating conditions may be expensive.
   // Limit it to usub_with_overflow calls for now.
-  if (match(CxtI,
+  if (match(SQ.CxtI,
             m_Intrinsic<Intrinsic::usub_with_overflow>(m_Value(), m_Value())))
-    if (auto C =
-            isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, CxtI, DL)) {
+    if (auto C = isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, SQ.CxtI,
+                                         SQ.DL)) {
       if (*C)
         return OverflowResult::NeverOverflows;
       return OverflowResult::AlwaysOverflowsLow;
     }
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/false, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/false, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
   return mapOverflowResult(LHSRange.unsignedSubMayOverflow(RHSRange));
 }
 
 OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
                                                  const Value *RHS,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
+                                                 const SimplifyQuery &SQ) {
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -6443,19 +6423,19 @@ OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
   // then determining no-overflow may allow other transforms.
   if (match(RHS, m_SRem(m_Specific(LHS), m_Value())) ||
       match(RHS, m_NSWSub(m_Specific(LHS), m_Value())))
-    if (isGuaranteedNotToBeUndefOrPoison(LHS, AC, CxtI, DT))
+    if (isGuaranteedNotToBeUndefOrPoison(LHS, SQ.AC, SQ.CxtI, SQ.DT))
       return OverflowResult::NeverOverflows;
 
   // If LHS and RHS each have at least two sign bits, the subtraction
   // cannot overflow.
-  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
-      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+  if (::ComputeNumSignBits(LHS, 0, SQ) > 1 &&
+      ::ComputeNumSignBits(RHS, 0, SQ) > 1)
     return OverflowResult::NeverOverflows;
 
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
   return mapOverflowResult(LHSRange.signedSubMayOverflow(RHSRange));
 }
 
@@ -6949,21 +6929,15 @@ bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
 }
 
 OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
+                                                 const SimplifyQuery &SQ) {
   return ::computeOverflowForSignedAdd(Add->getOperand(0), Add->getOperand(1),
-                                       Add, DL, AC, CxtI, DT);
+                                       Add, SQ);
 }
 
 OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
                                                  const Value *RHS,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
-  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, DL, AC, CxtI, DT);
+                                                 const SimplifyQuery &SQ) {
+  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
 }
 
 bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 4cb70cbdf093b36..224db19caac82ab 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -2540,8 +2540,9 @@ static bool hoistAdd(ICmpInst::Predicate Pred, Value *VariantLHS,
   // we want to avoid this.
   auto &DL = L.getHeader()->getModule()->getDataLayout();
   bool ProvedNoOverflowAfterReassociate =
-      computeOverflowForSignedSub(InvariantRHS, InvariantOp, DL, AC, &ICmp,
-                                  DT) == llvm::OverflowResult::NeverOverflows;
+      computeOverflowForSignedSub(InvariantRHS, InvariantOp,
+                                  SimplifyQuery(DL, DT, AC, &ICmp)) ==
+      llvm::OverflowResult::NeverOverflows;
   if (!ProvedNoOverflowAfterReassociate)
     return false;
   auto *Preheader = L.getLoopPreheader();
@@ -2591,15 +2592,16 @@ static bool hoistSub(ICmpInst::Predicate Pred, Value *VariantLHS,
   // we want to avoid this. Likewise, for "C1 - LV < C2" we need to prove that
   // "C1 - C2" does not overflow.
   auto &DL = L.getHeader()->getModule()->getDataLayout();
+  SimplifyQuery SQ(DL, DT, AC, &ICmp);
   if (VariantSubtracted) {
     // C1 - LV < C2 --> LV > C1 - C2
-    if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, DL, AC, &ICmp,
-                                    DT) != llvm::OverflowResult::NeverOverflows)
+    if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, SQ) !=
+        llvm::OverflowResult::NeverOverflows)
       return false;
   } else {
     // LV - C1 < C2 --> LV < C1 + C2
-    if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, DL, AC, &ICmp,
-                                    DT) != llvm::OverflowResult::NeverOverflows)
+    if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, SQ) !=
+        llvm::OverflowResult::NeverOverflows)
       return false;
   }
   auto *Preheader = L.getLoopPreheader();
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index edc8a4956dd1c71..b1add3c42976fd6 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -641,8 +641,9 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
   // Check if the multiply could not overflow due to known ranges of the
   // input values.
   OverflowResult OR = computeOverflowForUnsignedMul(
-      FI.InnerTripCount, FI.OuterTripCount, DL, AC,
-      FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
+      FI.InnerTripCount, FI.OuterTripCount,
+      SimplifyQuery(DL, DT, AC,
+                    FI.OuterLoop->getLoopPreheader()->getTerminator()));
   if (OR != OverflowResult::MayOverflow)
     return OR;
 
diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
index 9c3e9a2fd018aca..021b03e42403cd7 100644
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -372,9 +372,9 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
     // If the I-th index needs sext and the underlying add is not equipped with
     // nsw, we cannot split the add because
     //   sext(LHS + RHS) != sext(LHS) + sext(RHS).
+    SimplifyQuery SQ(*DL, DT, AC, GEP);
     if (requiresSignExtension(IndexToSplit, GEP) &&
-        computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) !=
-            OverflowResult::NeverOverflows)
+        computeOverflowForSignedAdd(AO, SQ) != OverflowResult::NeverOverflows)
       return nullptr;
 
     Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);



More information about the llvm-commits mailing list