[llvm] [ValueTracking] Add a helper to detect information loss (PR #82674)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 13:06:51 PST 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/82674

>From 84d8a2afb7253437c53fe724524efccdec3a9479 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 23 Feb 2024 04:26:29 +0800
Subject: [PATCH] [ValueTracking] Add a helper class to detect information loss

---
 llvm/include/llvm/Analysis/ValueTracking.h    |  18 ++++
 .../Transforms/InstCombine/InstCombiner.h     |   9 ++
 llvm/lib/Analysis/ValueTracking.cpp           | 102 ++++++++++++++++++
 .../InstCombine/InstructionCombining.cpp      |   7 ++
 4 files changed, 136 insertions(+)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index f0d0ee554f12b2..242eb96d25621f 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -24,6 +24,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include <cassert>
 #include <cstdint>
+#include <variant>
 
 namespace llvm {
 
@@ -1195,6 +1196,23 @@ std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
                                             const Value *LHS, const Value *RHS,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
+
+/// Debug-only helper that reports information (poison/undef-ness, KnownBits,
+/// KnownFPClass) lost when all uses of an instruction are replaced by another
+/// value. Out-of-line members are defined only when NDEBUG is not set.
+class ValueTrackingCache final {
+  Instruction *From; ///< Instruction whose uses are being replaced.
+  /// Poison/undef guarantees of From, recorded at construction.
+  bool NoPoison = false;
+  bool NoUndef = false;
+  /// What was known about From before the replacement, if anything.
+  std::variant<std::monostate, KnownBits, KnownFPClass> BeforeKnown;
+
+public:
+  explicit ValueTrackingCache(Instruction *FromInst, const SimplifyQuery &SQ);
+  void detectInformationLoss(Value *To, const SimplifyQuery &SQ);
+  Instruction *getFromInst() const noexcept { return From; }
+};
 } // end namespace llvm
 
 #endif // LLVM_ANALYSIS_VALUETRACKING_H
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 93090431cbb69f..6531714285108f 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -91,6 +91,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   /// Order of predecessors to canonicalize phi nodes towards.
   SmallDenseMap<BasicBlock *, SmallVector<BasicBlock *>, 8> PredOrder;
 
+  /// Information-loss checker for the instruction run() is visiting, or null.
+#ifndef NDEBUG
+  ValueTrackingCache *VTC = nullptr;
+#endif
+
 public:
   InstCombiner(InstructionWorklist &Worklist, BuilderTy &Builder,
                bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
@@ -402,6 +407,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
       V->takeName(&I);
 
     I.replaceAllUsesWith(V);
+#ifndef NDEBUG
+    if (VTC && &I == VTC->getFromInst())
+      VTC->detectInformationLoss(V, SQ);
+#endif
     return &I;
   }
 
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 04f317228b3ea7..be065139e99168 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -78,6 +78,7 @@
 #include <cstdint>
 #include <optional>
 #include <utility>
+#include <variant>
 
 using namespace llvm;
 using namespace llvm::PatternMatch;
@@ -87,6 +88,9 @@ using namespace llvm::PatternMatch;
 static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses",
                                               cl::Hidden, cl::init(20));
 
+static cl::opt<bool> DetectInformationLoss(
+    "detect-information-loss", cl::Hidden, cl::init(false),
+    cl::desc("Warn when a simplification loses known information"));
 
 /// Returns the bitwidth of the given scalar or pointer type. For vector types,
 /// returns the element type's bitwidth.
@@ -9053,3 +9057,101 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
 
   return CR;
 }
+
+#ifndef NDEBUG
+llvm::ValueTrackingCache::ValueTrackingCache(Instruction *FromInst,
+                                             const SimplifyQuery &SQ)
+    : From(FromInst) {
+  // Snapshot nothing unless -detect-information-loss is enabled.
+  if (!DetectInformationLoss)
+    return;
+  const SimplifyQuery Q = SQ.getWithInstruction(From);
+  NoPoison = isGuaranteedNotToBePoison(From, SQ.AC, From, SQ.DT);
+  NoUndef = isGuaranteedNotToBeUndef(From, SQ.AC, From, SQ.DT);
+
+  Type *Ty = From->getType();
+  if (Ty->isFPOrFPVectorTy()) {
+    // Record the FP classes only when something is actually known.
+    // TODO: use FMF flags
+    KnownFPClass KnownFP =
+        computeKnownFPClass(From, fcAllFlags, /*Depth=*/0, Q);
+    if (KnownFP.SignBit || KnownFP.KnownFPClasses != fcAllFlags)
+      BeforeKnown = KnownFP;
+  } else if (Ty->isIntOrIntVectorTy() || Ty->isPtrOrPtrVectorTy()) {
+    // Record the known bits only when at least one bit is known.
+    KnownBits KnownInt = computeKnownBits(From, /*Depth=*/0, Q);
+    if (!KnownInt.isUnknown())
+      BeforeKnown = KnownInt;
+  }
+}
+
+void llvm::ValueTrackingCache::detectInformationLoss(Value *To,
+                                                     const SimplifyQuery &SQ) {
+  if (!DetectInformationLoss)
+    return;
+  // Only replacements by an instruction are checked; it is the query context.
+  Instruction *ToInst = dyn_cast<Instruction>(To);
+  if (!ToInst)
+    return;
+  // Temporarily insert an unparented To after From; undone at the end.
+  bool Inserted = false;
+  if (!ToInst->getParent()) {
+    ToInst->insertAfter(From);
+    Inserted = true;
+  }
+
+  auto WarnOnInformationLoss = [&](StringRef Attr) {
+    errs() << "Warning: the attribute " << Attr << " got lost when simplifying "
+           << *From << " into " << *To << '\n';
+  };
+
+  // Poison/undef guarantees that held for From must also hold for To.
+  if (NoPoison && !isGuaranteedNotToBePoison(To, SQ.AC, ToInst, SQ.DT))
+    WarnOnInformationLoss("non-poison");
+  if (NoUndef && !isGuaranteedNotToBeUndef(To, SQ.AC, ToInst, SQ.DT))
+    WarnOnInformationLoss("non-undef");
+
+  Type *Ty = From->getType();
+  if ((Ty->isIntOrIntVectorTy() || Ty->isPtrOrPtrVectorTy()) &&
+      std::holds_alternative<KnownBits>(BeforeKnown)) {
+    const KnownBits &Before = std::get<KnownBits>(BeforeKnown);
+    KnownBits After =
+        computeKnownBits(To, /*Depth=*/0, SQ.getWithInstruction(ToInst));
+    // Every bit known (zero or one) for From must still be known for To.
+    if (!Before.Zero.isSubsetOf(After.Zero) ||
+        !Before.One.isSubsetOf(After.One)) {
+      WarnOnInformationLoss("knownbits");
+      errs() << "Before: " << Before << '\n';
+      errs() << "After: " << After << '\n';
+    }
+    assert((Before.One & After.Zero).isZero() && "Possible miscompilation");
+    assert((Before.Zero & After.One).isZero() && "Possible miscompilation");
+  } else if (Ty->isFPOrFPVectorTy() &&
+             std::holds_alternative<KnownFPClass>(BeforeKnown)) {
+    const KnownFPClass &Before = std::get<KnownFPClass>(BeforeKnown);
+    // TODO: use FMF flags
+    KnownFPClass After = computeKnownFPClass(To, fcAllFlags, /*Depth=*/0,
+                                             SQ.getWithInstruction(ToInst));
+    // KnownFPClasses is the set of *possible* classes: To's must be a subset.
+    if ((Before.KnownFPClasses & After.KnownFPClasses) !=
+        After.KnownFPClasses) {
+      WarnOnInformationLoss("fpclasses");
+      errs() << "Before: " << Before.KnownFPClasses << '\n';
+      errs() << "After: " << After.KnownFPClasses << '\n';
+    }
+    if (Before.SignBit.has_value() && !After.SignBit.has_value())
+      WarnOnInformationLoss("sign");
+    // An empty class set (no possible class, e.g. poison) contradicts nothing.
+    if (Before.KnownFPClasses != fcNone && After.KnownFPClasses != fcNone) {
+      assert((Before.KnownFPClasses & After.KnownFPClasses) != fcNone &&
+             "Possible miscompilation");
+      assert((!Before.SignBit.has_value() || !After.SignBit.has_value() ||
+              Before.SignBit == After.SignBit) &&
+             "Possible miscompilation");
+    }
+  }
+
+  if (Inserted)
+    ToInst->removeFromParent();
+}
+#endif
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 4af455c37c788c..768e00e5e80484 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4683,6 +4683,10 @@ bool InstCombinerImpl::run() {
 #endif
     LLVM_DEBUG(raw_string_ostream SS(OrigI); I->print(SS); OrigI = SS.str(););
     LLVM_DEBUG(dbgs() << "IC: Visiting: " << OrigI << '\n');
+#ifndef NDEBUG
+    ValueTrackingCache Cache(I, SQ);
+    VTC = &Cache;
+#endif
 
     if (Instruction *Result = visit(*I)) {
       ++NumCombined;
@@ -4718,6 +4722,9 @@ bool InstCombinerImpl::run() {
         Worklist.pushUsersToWorkList(*Result);
         Worklist.push(Result);
 
+#ifndef NDEBUG
+        Cache.detectInformationLoss(Result, SQ);
+#endif
         eraseInstFromFunction(*I);
       } else {
         LLVM_DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n'



More information about the llvm-commits mailing list