[llvm] c9e93c8 - Add Query API for llvm.assume holding attributes

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 18 10:42:49 PST 2020


Author: Tyker
Date: 2020-02-18T19:42:07+01:00
New Revision: c9e93c84f61400d1aac7d195a0578e80bc48c69a

URL: https://github.com/llvm/llvm-project/commit/c9e93c84f61400d1aac7d195a0578e80bc48c69a
DIFF: https://github.com/llvm/llvm-project/commit/c9e93c84f61400d1aac7d195a0578e80bc48c69a.diff

LOG: Add Query API for llvm.assume holding attributes

Reviewers: jdoerfert, sstefan1, uenoku

Reviewed By: jdoerfert

Subscribers: mgorny, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72885

Added: 
    llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp

Modified: 
    llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
    llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
    llvm/unittests/Transforms/Utils/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
index 16be357af00a..27d83373e074 100644
--- a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
+++ b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
@@ -16,6 +16,7 @@
 #ifndef LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
 #define LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
 
+#include "llvm/IR/Attributes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/PassManager.h"
 
@@ -30,6 +31,41 @@ inline CallInst *BuildAssumeFromInst(Instruction *I) {
   return BuildAssumeFromInst(I, I->getModule());
 }
 
+/// An llvm.assume may hold the same attribute more than once on the same
+/// llvm::Value, possibly with different arguments. This is rare but needs to
+/// be handled, so queries state which of those arguments they want.
+enum class AssumeQuery {
+  Highest, ///< Take the highest value available.
+  Lowest,  ///< Take the lowest value available.
+};
+
+/// Look in the operand bundles of the llvm.assume \p AssumeCI for an attribute
+/// named \p AttrName that applies to \p IsOn (or to anything if IsOn is null).
+///
+/// Every call searches the bundles from scratch, so this is intended for code
+/// that queries a single attribute; repeated queries redo that work each time.
+///
+/// Returns true iff such an attribute was found. If \p ArgVal is non-null and
+/// the attribute has an argument, it is stored there (the one picked by AQR).
+bool hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn, StringRef AttrName,
+                          uint64_t *ArgVal = nullptr,
+                          AssumeQuery AQR = AssumeQuery::Highest);
+inline bool hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn,
+                                 Attribute::AttrKind Kind,
+                                 uint64_t *ArgVal = nullptr,
+                                 AssumeQuery AQR = AssumeQuery::Highest) {
+  StringRef AttrName = Attribute::getNameFromAttrKind(Kind);
+  return hasAttributeInAssume(AssumeCI, IsOn, AttrName, ArgVal, AQR);
+}
+
+/// TODO: Add a function to create/fill a map from the bundle when users intend
+/// to make many different queries on the same bundles, to be used for example
+/// in the Attributor.
+
+//===----------------------------------------------------------------------===//
+// Utilities for testing
+//===----------------------------------------------------------------------===//
+
 /// This pass will try to build an llvm.assume for every instruction in the
 /// function. Its main purpose is testing.
 struct AssumeBuilderPass : public PassInfoMixin<AssumeBuilderPass> {

diff  --git a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
index bf8742bfd699..9cae22b5d61d 100644
--- a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
+++ b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
@@ -15,13 +15,13 @@
 
 using namespace llvm;
 
-namespace {
-
 cl::opt<bool> ShouldPreserveAllAttributes(
     "assume-preserve-all", cl::init(false), cl::Hidden,
     cl::desc("enable preservation of all attrbitues. even those that are "
              "unlikely to be usefull"));
 
+namespace {
+
 struct AssumedKnowledge {
   const char *Name;
   Value *Argument;
@@ -59,22 +59,33 @@ template <> struct DenseMapInfo<AssumedKnowledge> {
 
 namespace {
 
+/// Index of each element inside an operand bundle of an llvm.assume.
+/// When an element is present it is the one named here, but a bundle may be
+/// shorter than this enum suggests, e.g. missing its argument or its WasOn.
+enum BundleOpInfoElem {
+  BOIE_WasOn = 0,
+  BOIE_Argument = 1,
+};
+
 /// Deterministically compare OperandBundleDef.
 /// The ordering is:
-/// - by the name of the attribute, (doesn't change)
-/// - then by the Value of the argument, (doesn't change)
+/// - by the attribute's name, i.e. the operand bundle tag, (doesn't change)
+/// - then by the numeric value of the argument, 0 if absent, (doesn't change)
 /// - lastly by the Name of the current Value it WasOn. (may change)
 /// This order is deterministic and allows looking for the right kind of
 /// attribute with binary search. However finding the right WasOn needs to be
-/// done via linear search because values can get remplaced.
+/// done via linear search because values can get replaced.
 bool isLowerOpBundle(const OperandBundleDef &LHS, const OperandBundleDef &RHS) {
   auto getTuple = [](const OperandBundleDef &Op) {
     return std::make_tuple(
         Op.getTag(),
-        Op.input_size() < 2
+        Op.input_size() <= BOIE_Argument
             ? 0
-            : cast<ConstantInt>(*std::next(Op.input_begin()))->getZExtValue(),
-        Op.input_size() < 1 ? StringRef("") : (*Op.input_begin())->getName());
+            : cast<ConstantInt>(*(Op.input_begin() + BOIE_Argument))
+                  ->getZExtValue(),
+         Op.input_size() <= BOIE_WasOn
+            ? StringRef("")
+            : (*(Op.input_begin() + BOIE_WasOn))->getName());
   };
   return getTuple(LHS) < getTuple(RHS);
 }
@@ -160,6 +171,88 @@ CallInst *llvm::BuildAssumeFromInst(const Instruction *I, Module *M) {
   return Builder.build();
 }
 
+#ifndef NDEBUG
+/// True iff \p Name is listed in Attributes.inc; only used by asserts.
+static bool isExistingAttribute(StringRef Name) {
+  return StringSwitch<bool>(Name)
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) .Case(#DISPLAY_NAME, true)
+#include "llvm/IR/Attributes.inc"
+      .Default(false);
+}
+
+#endif
+
+bool llvm::hasAttributeInAssume(CallInst &AssumeCI, Value *IsOn,
+                                StringRef AttrName, uint64_t *ArgVal,
+                                AssumeQuery AQR) {
+  IntrinsicInst &Assume = cast<IntrinsicInst>(AssumeCI);
+  assert(Assume.getIntrinsicID() == Intrinsic::assume &&
+         "this function is intended to be used on llvm.assume");
+  assert(isExistingAttribute(AttrName) && "this attribute doesn't exist");
+  assert((ArgVal == nullptr || Attribute::doesAttrKindHaveArgument(
+                                   Attribute::getAttrKindFromName(AttrName))) &&
+         "requested value for an attribute that has no argument");
+  if (Assume.bundle_op_infos().empty())
+    return false;
+
+  // Bundles are sorted by tag first (see isLowerOpBundle), so all bundles with
+  // the requested tag form the range [First, Last) found by binary search.
+  // The half-open range means an absent tag yields First == Last instead of
+  // stepping before the first bundle, as std::prev(upper_bound) could.
+  CallInst::bundle_op_iterator First = llvm::lower_bound(
+      Assume.bundle_op_infos(), AttrName,
+      [](const CallBase::BundleOpInfo &BOI, StringRef RHS) {
+        assert(isExistingAttribute(BOI.Tag->getKey()) &&
+               "this attribute doesn't exist");
+        return BOI.Tag->getKey() < RHS;
+      });
+  CallInst::bundle_op_iterator Last = llvm::upper_bound(
+      Assume.bundle_op_infos(), AttrName,
+      [](StringRef LHS, const CallBase::BundleOpInfo &BOI) {
+        assert(isExistingAttribute(BOI.Tag->getKey()) &&
+               "this attribute doesn't exist");
+        return LHS < BOI.Tag->getKey();
+      });
+
+  auto getValueFromBundleOpInfo = [&Assume](const CallBase::BundleOpInfo &BOI,
+                                            unsigned Idx) {
+    assert(BOI.End - BOI.Begin > Idx && "index out of range");
+    return (Assume.op_begin() + BOI.Begin + Idx)->get();
+  };
+  // A bundle applies if no IsOn was requested, or if it has a WasOn operand
+  // equal to IsOn. This is checked on every candidate because a bundle may
+  // lack the WasOn operand and values may have been replaced since sorting.
+  auto AppliesToIsOn = [&](const CallBase::BundleOpInfo &BOI) {
+    return !IsOn || (BOI.End - BOI.Begin > BOIE_WasOn &&
+                     getValueFromBundleOpInfo(BOI, BOIE_WasOn) == IsOn);
+  };
+
+  // Within [First, Last) bundles are ordered by ascending argument, so the
+  // first match from the front has the lowest argument and the first match
+  // from the back has the highest one.
+  const CallBase::BundleOpInfo *Found = nullptr;
+  if (AQR == AssumeQuery::Lowest) {
+    for (auto It = First; It != Last && !Found; ++It)
+      if (AppliesToIsOn(*It))
+        Found = &*It;
+  } else {
+    for (auto It = Last; It != First && !Found; --It)
+      if (AppliesToIsOn(*std::prev(It)))
+        Found = &*std::prev(It);
+  }
+  if (!Found)
+    return false;
+
+  // Attributes without an argument have no BOIE_Argument operand; leave
+  // ArgVal untouched for them instead of reading past the end of the bundle.
+  if (ArgVal && Found->End - Found->Begin > BOIE_Argument)
+    *ArgVal =
+        cast<ConstantInt>(getValueFromBundleOpInfo(*Found, BOIE_Argument))
+            ->getZExtValue();
+  return true;
+}
+
 PreservedAnalyses AssumeBuilderPass::run(Function &F,
                                          FunctionAnalysisManager &AM) {
   for (Instruction &I : instructions(F))

diff  --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
index c9c0f9f84165..9b0d7f0f0844 100644
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -15,6 +15,7 @@ add_llvm_unittest(UtilsTests
   CodeMoverUtilsTest.cpp
   FunctionComparatorTest.cpp
   IntegerDivisionTest.cpp
+  KnowledgeRetentionTest.cpp
   LocalTest.cpp
   LoopRotationUtilsTest.cpp
   LoopUtilsTest.cpp

diff  --git a/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp b/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
new file mode 100644
index 000000000000..229098d1cfc4
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/KnowledgeRetentionTest.cpp
@@ -0,0 +1,215 @@
+//===- KnowledgeRetentionTest.cpp - KnowledgeRetention unit tests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/KnowledgeRetention.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/CommandLine.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+extern cl::opt<bool> ShouldPreserveAllAttributes;
+
+static void RunTest(
+    StringRef Head, StringRef Tail,
+    std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
+        &Tests) {
+  std::string IR;
+  IR.append(Head.begin(), Head.end());
+  for (auto &Elem : Tests)
+    IR.append(Elem.first.begin(), Elem.first.end());
+  IR.append(Tail.begin(), Tail.end());
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("AssumeQueryAPI", errs());
+  ASSERT_TRUE(Mod && Mod->getFunction("test")) << "invalid test IR";
+  // Each test adds one instruction, so the Idx-th one belongs to Tests[Idx].
+  unsigned Idx = 0;
+  for (Instruction &I : Mod->getFunction("test")->getEntryBlock())
+    if (Idx < Tests.size())
+      Tests[Idx++].second(&I);
+}
+
+static void AssertMatchesExactlyAttributes(CallInst *Assume, Value *WasOn,
+                                           StringRef AttrToMatch) {
+  Regex Reg(AttrToMatch); // Attributes whose whole name matches are expected.
+  SmallVector<StringRef, 1> Matches;
+  for (StringRef Attr : {
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ALL(ENUM_NAME, DISPLAY_NAME) StringRef(#DISPLAY_NAME),
+#include "llvm/IR/Attributes.inc"
+       }) {
+    bool ShouldHaveAttr = Reg.match(Attr, &Matches) && Matches[0] == Attr;
+    ASSERT_EQ(ShouldHaveAttr, hasAttributeInAssume(*Assume, WasOn, Attr))
+        << "attribute: " << Attr.str();
+  }
+}
+
+static void AssertHasTheRightValue(CallInst *Assume, Value *WasOn,
+                                   Attribute::AttrKind Kind, unsigned Expected,
+                                   bool Both,
+                                   AssumeQuery AQ = AssumeQuery::Highest) {
+  // With Both, AQ is ignored: Lowest and Highest must each find Expected.
+  if (!Both) {
+    uint64_t ArgVal = 0;
+    ASSERT_TRUE(hasAttributeInAssume(*Assume, WasOn, Kind, &ArgVal, AQ));
+    ASSERT_EQ(ArgVal, Expected);
+    return;
+  }
+  uint64_t ArgValLow = 0;
+  uint64_t ArgValHigh = 0;
+  ASSERT_TRUE(hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValLow,
+                                   AssumeQuery::Lowest));
+  ASSERT_TRUE(hasAttributeInAssume(*Assume, WasOn, Kind, &ArgValHigh,
+                                   AssumeQuery::Highest));
+  ASSERT_EQ(ArgValLow, Expected);
+  ASSERT_EQ(ArgValHigh, Expected);
+}
+
+TEST(AssumeQueryAPI, Basic) {
+  StringRef Head =
+      "declare void @llvm.assume(i1)\n"
+      "declare void @func(i32*, i32*)\n"
+      "declare void @func1(i32*, i32*, i32*, i32*)\n"
+      "declare void @func_many(i32*) \"no-jump-tables\" nounwind "
+      "\"less-precise-fpmad\" willreturn norecurse\n"
+      "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {\n";
+  StringRef Tail = "ret void\n"
+                   "}";
+  // function_ref does not own its callee: each lambda below dies with its
+  // make_pair temporary, so this only works because none of them captures.
+  std::vector<std::pair<StringRef, llvm::function_ref<void(Instruction *)>>>
+      Tests;
+  Tests.push_back(std::make_pair(
+      "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
+      "8 noalias %P1)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(0),
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
+                                       "(noalias|align)");
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Dereferenceable, 16, true);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Alignment, 4, true);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Alignment, 8, true);
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func1(i32* nonnull align 32 dereferenceable(48) %P, i32* "
+      "nonnull align 8 dereferenceable(28) %P, i32* nonnull align 64 "
+      "dereferenceable(4) %P, i32* nonnull align 16 dereferenceable(12) %P)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(0),
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(2),
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(3),
+                                       "(nonnull|align|dereferenceable)");
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Dereferenceable, 48, false,
+                               AssumeQuery::Highest);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Alignment, 64, false,
+                               AssumeQuery::Highest);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Alignment, 64, false,
+                               AssumeQuery::Highest);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Dereferenceable, 4, false,
+                               AssumeQuery::Lowest);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Alignment, 8, false,
+                               AssumeQuery::Lowest);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Alignment, 8, false,
+                               AssumeQuery::Lowest);
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func_many(i32* align 8 %P1) cold\n", [](Instruction *I) {
+        ShouldPreserveAllAttributes.setValue(true);
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        AssertMatchesExactlyAttributes(
+            Assume, nullptr,
+            "(align|no-jump-tables|less-precise-fpmad|"
+            "nounwind|norecurse|willreturn|cold)");
+        ShouldPreserveAllAttributes.setValue(false);
+      }));
+  Tests.push_back(
+      std::make_pair("call void @llvm.assume(i1 true)\n", [](Instruction *I) {
+        CallInst *Assume = cast<CallInst>(I);
+        AssertMatchesExactlyAttributes(Assume, nullptr, "");
+      }));
+  Tests.push_back(std::make_pair(
+      "call void @func1(i32* readnone align 32 "
+      "dereferenceable(48) noalias %P, i32* "
+      "align 8 dereferenceable(28) %P1, i32* align 64 "
+      "dereferenceable(4) "
+      "%P2, i32* nonnull align 16 dereferenceable(12) %P3)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        AssertMatchesExactlyAttributes(
+            Assume, I->getOperand(0),
+            "(readnone|align|dereferenceable|noalias)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(1),
+                                       "(align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(2),
+                                       "(align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, I->getOperand(3),
+                                       "(nonnull|align|dereferenceable)");
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Alignment, 32, true);
+        AssertHasTheRightValue(Assume, I->getOperand(0),
+                               Attribute::AttrKind::Dereferenceable, 48, true);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Dereferenceable, 28, true);
+        AssertHasTheRightValue(Assume, I->getOperand(1),
+                               Attribute::AttrKind::Alignment, 8, true);
+        AssertHasTheRightValue(Assume, I->getOperand(2),
+                               Attribute::AttrKind::Alignment, 64, true);
+        AssertHasTheRightValue(Assume, I->getOperand(2),
+                               Attribute::AttrKind::Dereferenceable, 4, true);
+        AssertHasTheRightValue(Assume, I->getOperand(3),
+                               Attribute::AttrKind::Alignment, 16, true);
+        AssertHasTheRightValue(Assume, I->getOperand(3),
+                               Attribute::AttrKind::Dereferenceable, 12, true);
+      }));
+
+  /// Keep this test last as it modifies the function.
+  Tests.push_back(std::make_pair(
+      "call void @func(i32* nonnull align 4 dereferenceable(16) %P, i32* align "
+      "8 noalias %P1)\n",
+      [](Instruction *I) {
+        CallInst *Assume = BuildAssumeFromInst(I);
+        Assume->insertBefore(I);
+        Value *New = I->getFunction()->getArg(3);
+        Value *Old = I->getOperand(0);
+        AssertMatchesExactlyAttributes(Assume, New, "");
+        AssertMatchesExactlyAttributes(Assume, Old,
+                                       "(nonnull|align|dereferenceable)");
+        Old->replaceAllUsesWith(New);
+        AssertMatchesExactlyAttributes(Assume, New,
+                                       "(nonnull|align|dereferenceable)");
+        AssertMatchesExactlyAttributes(Assume, Old, "");
+      }));
+  RunTest(Head, Tail, Tests);
+}


        


More information about the llvm-commits mailing list