[llvm] 8405639 - [AssumeBundles] Add API to query bundles from a use

via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 8 04:04:46 PDT 2020


Author: Tyker
Date: 2020-03-08T12:04:23+01:00
New Revision: 84056394e97885e1d7c588992d725f188d134e63

URL: https://github.com/llvm/llvm-project/commit/84056394e97885e1d7c588992d725f188d134e63
DIFF: https://github.com/llvm/llvm-project/commit/84056394e97885e1d7c588992d725f188d134e63.diff

LOG: [AssumeBundles] Add API to query bundles from a use

Summary: Finding what information is known about a value from a use is generally useful and can be done quickly.

Reviewers: jdoerfert

Reviewed By: jdoerfert

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D75616

Added: 
    

Modified: 
    llvm/include/llvm/IR/InstrTypes.h
    llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
    llvm/lib/IR/Instructions.cpp
    llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
    llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index 3c87257b8c02..cad3f45ed5d2 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -2097,16 +2097,14 @@ class CallBase : public Instruction {
   op_iterator populateBundleOperandInfos(ArrayRef<OperandBundleDef> Bundles,
                                          const unsigned BeginIndex);
 
+public: // Used by free functions such as getKnowledgeFromOperandInAssume.
   /// Return the BundleOpInfo for the operand at index OpIdx.
   ///
   /// It is an error to call this with an OpIdx that does not correspond to an
   /// bundle operand.
+  BundleOpInfo &getBundleOpInfoForOperand(unsigned OpIdx); // Out of line.
   const BundleOpInfo &getBundleOpInfoForOperand(unsigned OpIdx) const {
-    for (auto &BOI : bundle_op_infos())
-      if (BOI.Begin <= OpIdx && OpIdx < BOI.End)
-        return BOI;
-
-    llvm_unreachable("Did not find operand bundle for operand!");
+    return const_cast<CallBase *>(this)->getBundleOpInfoForOperand(OpIdx);
   }
 
 protected:

diff  --git a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
index fcc0dae76fe1..e5997bf3d28d 100644
--- a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
+++ b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
@@ -93,6 +93,28 @@ using RetainedKnowledgeMap = DenseMap<RetainedKnowledgeKey, MinMax>;
 /// If the IR changes the map will be outdated.
 void fillMapFromAssume(CallInst &AssumeCI, RetainedKnowledgeMap &Result);
 
+/// Represents one fact held inside an operand bundle of an llvm.assume.
+/// AttrKind is the property that holds.
+/// WasOn, if not null, is the Value for which AttrKind holds.
+/// ArgValue is an optional integer argument; it stays 0 when absent.
+struct RetainedKnowledge {
+  Attribute::AttrKind AttrKind = Attribute::None;
+  Value *WasOn = nullptr;
+  unsigned ArgValue = 0;
+};
+
+/// Retrieve the information held by Assume on the operand at index Idx.
+/// Assume must be an llvm.assume and Idx must index a bundle operand.
+RetainedKnowledge getKnowledgeFromOperandInAssume(CallInst &Assume,
+                                                  unsigned Idx);
+
+/// Retrieve the information held by the Use U of an llvm.assume. The use
+/// must be one of the assume's operand bundle operands.
+inline RetainedKnowledge getKnowledgeFromUseInAssume(const Use *U) {
+  return getKnowledgeFromOperandInAssume(*cast<CallInst>(U->getUser()),
+                                         U->getOperandNo());
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for testing
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 33ad40343bef..68eed612e4bf 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -384,6 +384,64 @@ CallBase::populateBundleOperandInfos(ArrayRef<OperandBundleDef> Bundles,
   return It;
 }
 
+CallBase::BundleOpInfo &CallBase::getBundleOpInfoForOperand(unsigned OpIdx) {
+  /// When there aren't many bundles, we do a simple linear search.
+  /// Otherwise fall back to an interpolation search that uses the fact that
+  /// bundles usually have a similar number of operands to converge faster.
+  if (bundle_op_info_end() - bundle_op_info_begin() < 8) {
+    for (auto &BOI : bundle_op_infos())
+      if (BOI.Begin <= OpIdx && OpIdx < BOI.End)
+        return BOI;
+
+    llvm_unreachable("Did not find operand bundle for operand!");
+  }
+
+  assert(OpIdx >= arg_size() && "the Idx is not in the operand bundles");
+  assert(bundle_op_info_end() - bundle_op_info_begin() > 0 &&
+         OpIdx < std::prev(bundle_op_info_end())->End &&
+         "The Idx isn't in the operand bundle");
+
+  /// We need a fractional number below; to avoid floating point we use an
+  /// integral value multiplied by this constant. The arithmetic is done in
+  /// 64 bits so the scaled products cannot overflow.
+  constexpr uint64_t NumberScaling = 1024;
+
+  bundle_op_iterator Begin = bundle_op_info_begin();
+  bundle_op_iterator End = bundle_op_info_end();
+  bundle_op_iterator Current = Begin;
+
+  // The search keeps OpIdx within [Begin->Begin, std::prev(End)->End).
+  while (Begin != End) {
+    uint64_t NumBundles = End - Begin;
+    uint64_t ScaledOperandPerBundle =
+        NumberScaling * (std::prev(End)->End - Begin->Begin) / NumBundles;
+    // Estimate the position of OpIdx assuming evenly sized bundles. Fall back
+    // to the midpoint when the range is dominated by empty bundles (the
+    // average rounds to zero), and clamp the estimate before forming the
+    // iterator so it never points past End.
+    uint64_t Offset = NumBundles / 2;
+    if (ScaledOperandPerBundle != 0) {
+      Offset = (uint64_t(OpIdx - Begin->Begin) * NumberScaling) /
+               ScaledOperandPerBundle;
+      if (Offset >= NumBundles)
+        Offset = NumBundles - 1;
+    }
+    Current = Begin + Offset;
+    assert(Current < End && Current >= Begin &&
+           "the operand bundle doesn't cover every value in the range");
+    if (OpIdx >= Current->Begin && OpIdx < Current->End)
+      break;
+    if (OpIdx >= Current->End)
+      Begin = Current + 1;
+    else
+      End = Current;
+  }
+
+  assert(OpIdx >= Current->Begin && OpIdx < Current->End &&
+         "the operand bundle doesn't cover every value in the range");
+  return *Current;
+}
+
 //===----------------------------------------------------------------------===//
 //                        CallInst Implementation
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
index f2f87f9200ed..963bd22ee006 100644
--- a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
+++ b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
@@ -288,6 +288,23 @@ void llvm::fillMapFromAssume(CallInst &AssumeCI, RetainedKnowledgeMap &Result) {
   }
 }
 
+RetainedKnowledge llvm::getKnowledgeFromOperandInAssume(CallInst &AssumeCI,
+                                                        unsigned Idx) {
+  IntrinsicInst &Assume = cast<IntrinsicInst>(AssumeCI);
+  assert(Assume.getIntrinsicID() == Intrinsic::assume &&
+         "this function is intended to be used on llvm.assume");
+  CallBase::BundleOpInfo BOI = Assume.getBundleOpInfoForOperand(Idx);
+  RetainedKnowledge Result;
+  Result.AttrKind = Attribute::getAttrKindFromName(BOI.Tag->getKey());
+  Result.WasOn = getValueFromBundleOpInfo(Assume, BOI, BOIE_WasOn);
+  if (BOI.End - BOI.Begin > BOIE_Argument) // The argument is optional.
+    Result.ArgValue =
+        cast<ConstantInt>(getValueFromBundleOpInfo(Assume, BOI, BOIE_Argument))
+            ->getZExtValue(); // NOTE(review): narrowed to unsigned ArgValue.
+  // With no argument operand, ArgValue keeps its default value.
+  return Result;
+}
+
 PreservedAnalyses AssumeBuilderPass::run(Function &F,
                                          FunctionAnalysisManager &AM) {
   for (Instruction &I : instructions(F))

diff  --git a/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp b/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
index 08f2c6441645..ed37e9c76858 100644
--- a/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
+++ b/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
@@ -10,10 +10,12 @@
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/CommandLine.h"
 #include "gtest/gtest.h"
+#include <random>
 
 using namespace llvm;
 
@@ -387,3 +389,103 @@ TEST(AssumeQueryAPI, fillMapFromAssume) {
       }));
   RunTest(Head, Tail, Tests);
 }
+
+/// Builds an llvm.assume carrying Size randomly generated operand bundles and
+/// checks, for every function argument used by a bundle, that
+/// getKnowledgeFromUseInAssume agrees with the map built by fillMapFromAssume.
+/// A bundle whose drawn count is <= 0 has no operands, count 1 adds a pointer
+/// argument, and count >= 2 also adds an i32 constant in [0, MaxValue].
+static void RunRandTest(uint64_t Seed, int Size, int MinCount, int MaxCount,
+                        unsigned MaxValue) {
+  LLVMContext C;
+
+  std::mt19937 Rng(Seed);
+  std::uniform_int_distribution<int> DistCount(MinCount, MaxCount);
+  std::uniform_int_distribution<unsigned> DistValue(0, MaxValue);
+  std::uniform_int_distribution<unsigned> DistAttr(0,
+                                                   Attribute::EndAttrKinds - 1);
+
+  std::unique_ptr<Module> Mod = std::make_unique<Module>("AssumeQueryAPI", C);
+
+  std::vector<Type *> TypeArgs;
+  for (int i = 0; i < (Size * 2); i++)
+    TypeArgs.push_back(Type::getInt32PtrTy(C));
+  FunctionType *FuncType =
+      FunctionType::get(Type::getVoidTy(C), TypeArgs, false);
+
+  Function *F =
+      Function::Create(FuncType, GlobalValue::ExternalLinkage, "test", &*Mod);
+  BasicBlock *BB = BasicBlock::Create(C);
+  BB->insertInto(F);
+  Instruction *Ret = ReturnInst::Create(C);
+  BB->getInstList().insert(BB->begin(), Ret);
+  Function *FnAssume = Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume);
+
+  std::vector<Argument *> ShuffledArgs;
+  std::vector<bool> HasArg;
+  for (auto &Arg : F->args()) {
+    ShuffledArgs.push_back(&Arg);
+    HasArg.push_back(false);
+  }
+
+  std::shuffle(ShuffledArgs.begin(), ShuffledArgs.end(), Rng);
+
+  std::vector<OperandBundleDef> OpBundle;
+  OpBundle.reserve(Size);
+  std::vector<Value *> Args;
+  Args.reserve(2);
+  for (int i = 0; i < Size; i++) {
+    int count = DistCount(Rng);
+    unsigned value = DistValue(Rng);
+    int attr = DistAttr(Rng);
+    std::string str;
+    raw_string_ostream ss(str);
+    ss << Attribute::getNameFromAttrKind(
+        static_cast<Attribute::AttrKind>(attr));
+    Args.clear();
+
+    if (count > 0) {
+      Args.push_back(ShuffledArgs[i]);
+      HasArg[i] = true;
+    }
+    if (count > 1)
+      Args.push_back(ConstantInt::get(Type::getInt32Ty(C), value));
+
+    OpBundle.push_back(OperandBundleDef{ss.str(), std::move(Args)});
+  }
+
+  Instruction *Assume =
+      CallInst::Create(FnAssume, ArrayRef<Value *>({ConstantInt::getTrue(C)}),
+                       std::move(OpBundle));
+  Assume->insertBefore(&F->begin()->front());
+  RetainedKnowledgeMap Map;
+  fillMapFromAssume(*cast<CallInst>(Assume), Map);
+  for (int i = 0; i < (Size * 2); i++) {
+    if (!HasArg[i])
+      continue;
+    RetainedKnowledge K =
+        getKnowledgeFromUseInAssume(&*ShuffledArgs[i]->use_begin());
+    auto LookupIt = Map.find(RetainedKnowledgeKey{K.WasOn, K.AttrKind});
+    ASSERT_TRUE(LookupIt != Map.end());
+    MinMax MM = LookupIt->second;
+    ASSERT_TRUE(MM.Min == MM.Max);
+    ASSERT_TRUE(MM.Min == K.ArgValue);
+  }
+}
+
+TEST(AssumeQueryAPI, getKnowledgeFromUseInAssume) {
+  // // For Fuzzing
+  // std::random_device dev;
+  // std::mt19937 Rng(dev());
+  // while (true) {
+  //   unsigned Seed = Rng();
+  //   dbgs() << Seed << "\n";
+  //   RunRandTest(Seed, 100000, 0, 2, 100);
+  // }
+  RunRandTest(23456, 4, 0, 2, 100);
+  RunRandTest(560987, 25, -3, 2, 100);
+
+  // Large bundles can lead to special cases. This is why this test is so
+  // large.
+  RunRandTest(9876789, 100000, 0, 7, 100);
+}


        


More information about the llvm-commits mailing list