[llvm] 01909b4 - [IR] Make Module::setProfileSummary replace an existing ProfileSummary flag.

Hiroshi Yamauchi via llvm-commits llvm-commits at lists.llvm.org
Thu May 21 11:38:51 PDT 2020


Author: Hiroshi Yamauchi
Date: 2020-05-21T11:38:39-07:00
New Revision: 01909b4e850846bb4cf5226072ccc608c68c9466

URL: https://github.com/llvm/llvm-project/commit/01909b4e850846bb4cf5226072ccc608c68c9466
DIFF: https://github.com/llvm/llvm-project/commit/01909b4e850846bb4cf5226072ccc608c68c9466.diff

LOG: [IR] Make Module::setProfileSummary replace an existing ProfileSummary flag.

Summary:
Module::setProfileSummary currently calls addModuleFlag. This prevents the
ProfileSummary metadata in the module from being updated: a second
ProfileSummary flag is added instead of the existing one being replaced. I don't
think this is the expected behavior, since it does not make sense for a module
to have more than one ProfileSummary. To address this, add
Module::setModuleFlag and use it from setProfileSummary.

Reviewers: davidxl

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79902

Added: 
    

Modified: 
    llvm/include/llvm/IR/Module.h
    llvm/lib/IR/Module.cpp
    llvm/unittests/IR/ModuleTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Module.h b/llvm/include/llvm/IR/Module.h
index 3895e99c10d2..3052651a3722 100644
--- a/llvm/include/llvm/IR/Module.h
+++ b/llvm/include/llvm/IR/Module.h
@@ -156,6 +156,11 @@ class Module {
   /// converted result in MFB.
   static bool isValidModFlagBehavior(Metadata *MD, ModFlagBehavior &MFB);
 
+  /// Check if ModFlag is a valid module flag (>= 3 operands, a valid behavior
+  /// and an MDString key); if so, store its behavior, key and value metadata.
+  static bool isValidModuleFlag(const MDNode &ModFlag, ModFlagBehavior &MFB,
+                                MDString *&Key, Metadata *&Val);
+
   struct ModuleFlagEntry {
     ModFlagBehavior Behavior;
     MDString *Key;
@@ -493,10 +498,12 @@ class Module {
   void addModuleFlag(ModFlagBehavior Behavior, StringRef Key, Constant *Val);
   void addModuleFlag(ModFlagBehavior Behavior, StringRef Key, uint32_t Val);
   void addModuleFlag(MDNode *Node);
+  /// Like addModuleFlag, but updates an existing flag with the same Key.
+  void setModuleFlag(ModFlagBehavior Behavior, StringRef Key, Metadata *Val);
 
-/// @}
-/// @name Materialization
-/// @{
+  /// @}
+  /// @name Materialization
+  /// @{
 
   /// Sets the GVMaterializer to GVM. This module must not yet have a
   /// Materializer. To reset the materializer for a module that already has one,

diff  --git a/llvm/lib/IR/Module.cpp b/llvm/lib/IR/Module.cpp
index f1acf4653de6..9ac1edb2519d 100644
--- a/llvm/lib/IR/Module.cpp
+++ b/llvm/lib/IR/Module.cpp
@@ -283,6 +283,20 @@ bool Module::isValidModFlagBehavior(Metadata *MD, ModFlagBehavior &MFB) {
   return false;
 }
 
+bool Module::isValidModuleFlag(const MDNode &ModFlag, ModFlagBehavior &MFB,
+                               MDString *&Key, Metadata *&Val) {
+  // A module flag is a tuple !{behavior, key, value} (extra operands allowed).
+  if (ModFlag.getNumOperands() < 3 ||
+      !isValidModFlagBehavior(ModFlag.getOperand(0), MFB))
+    return false;
+  auto *FlagKey = dyn_cast_or_null<MDString>(ModFlag.getOperand(1));
+  if (!FlagKey)
+    return false;
+  Key = FlagKey;
+  Val = ModFlag.getOperand(2);
+  return true;
+}
+
 /// getModuleFlagsMetadata - Returns the module flags in the provided vector.
 void Module::
 getModuleFlagsMetadata(SmallVectorImpl<ModuleFlagEntry> &Flags) const {
@@ -291,13 +305,11 @@ getModuleFlagsMetadata(SmallVectorImpl<ModuleFlagEntry> &Flags) const {
 
   for (const MDNode *Flag : ModFlags->operands()) {
     ModFlagBehavior MFB;
-    if (Flag->getNumOperands() >= 3 &&
-        isValidModFlagBehavior(Flag->getOperand(0), MFB) &&
-        dyn_cast_or_null<MDString>(Flag->getOperand(1))) {
-      // Check the operands of the MDNode before accessing the operands.
-      // The verifier will actually catch these failures.
-      MDString *Key = cast<MDString>(Flag->getOperand(1));
-      Metadata *Val = Flag->getOperand(2);
+    MDString *Key = nullptr;
+    Metadata *Val = nullptr;
+    // isValidModuleFlag checks the operands before extracting them, so
+    // malformed flags are skipped here; the verifier will catch them.
+    if (isValidModuleFlag(*Flag, MFB, Key, Val)) {
       Flags.push_back(ModuleFlagEntry(MFB, Key, Val));
     }
   }
@@ -358,6 +370,29 @@ void Module::addModuleFlag(MDNode *Node) {
   getOrInsertModuleFlagsMetadata()->addOperand(Node);
 }
 
+void Module::setModuleFlag(ModFlagBehavior Behavior, StringRef Key,
+                           Metadata *Val) {
+  NamedMDNode *ModFlags = getOrInsertModuleFlagsMetadata();
+  // Replace the flag if it already exists. Install a new flag node instead of
+  // mutating the existing one in place (it may be uniqued and referenced
+  // elsewhere), so that the requested Behavior takes effect as well as Val.
+  for (unsigned I = 0, E = ModFlags->getNumOperands(); I != E; ++I) {
+    MDNode *Flag = ModFlags->getOperand(I);
+    ModFlagBehavior MFB;
+    MDString *K = nullptr;
+    Metadata *V = nullptr;
+    if (isValidModuleFlag(*Flag, MFB, K, V) && K->getString() == Key) {
+      Type *Int32Ty = Type::getInt32Ty(Context);
+      Metadata *Ops[3] = {
+          ConstantAsMetadata::get(ConstantInt::get(Int32Ty, Behavior)), K,
+          Val};
+      ModFlags->setOperand(I, MDNode::get(Context, Ops));
+      return;
+    }
+  }
+  addModuleFlag(Behavior, Key, Val);
+}
+
 void Module::setDataLayout(StringRef Desc) {
   DL.reset(Desc);
 }
@@ -547,9 +576,10 @@ void Module::setCodeModel(CodeModel::Model CL) {
 
 void Module::setProfileSummary(Metadata *M, ProfileSummary::Kind Kind) {
-  if (Kind == ProfileSummary::PSK_CSInstr)
-    addModuleFlag(ModFlagBehavior::Error, "CSProfileSummary", M);
-  else
-    addModuleFlag(ModFlagBehavior::Error, "ProfileSummary", M);
+  // PSK_CSInstr summaries use the "CSProfileSummary" key, all others use
+  // "ProfileSummary"; an existing flag with that key is replaced.
+  StringRef Key = Kind == ProfileSummary::PSK_CSInstr ? "CSProfileSummary"
+                                                      : "ProfileSummary";
+  setModuleFlag(ModFlagBehavior::Error, Key, M);
 }
 
 Metadata *Module::getProfileSummary(bool IsCS) {

diff  --git a/llvm/unittests/IR/ModuleTest.cpp b/llvm/unittests/IR/ModuleTest.cpp
index f642b002a5eb..7b34d5d0ee55 100644
--- a/llvm/unittests/IR/ModuleTest.cpp
+++ b/llvm/unittests/IR/ModuleTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/IR/Module.h"
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/RandomNumberGenerator.h"
@@ -72,4 +73,56 @@ TEST(ModuleTest, randomNumberGenerator) {
                           RandomStreams[1].begin()));
 }
 
+TEST(ModuleTest, setModuleFlag) {
+  LLVMContext Context;
+  Module M("M", Context);
+  StringRef Key = "Key";
+  Metadata *Val1 = MDString::get(Context, "Val1");
+  Metadata *Val2 = MDString::get(Context, "Val2");
+  EXPECT_EQ(nullptr, M.getModuleFlag(Key));
+  M.setModuleFlag(Module::ModFlagBehavior::Error, Key, Val1);
+  EXPECT_EQ(Val1, M.getModuleFlag(Key));
+  M.setModuleFlag(Module::ModFlagBehavior::Error, Key, Val2);
+  EXPECT_EQ(Val2, M.getModuleFlag(Key));
+  // Setting an existing key must replace the flag, not add a duplicate.
+  EXPECT_EQ(1u, M.getModuleFlagsMetadata()->getNumOperands());
+}
+
+const char *IRString = R"IR(
+  !llvm.module.flags = !{!0}
+
+  !0 = !{i32 1, !"ProfileSummary", !1}
+  !1 = !{!2, !3, !4, !5, !6, !7, !8, !9}
+  !2 = !{!"ProfileFormat", !"SampleProfile"}
+  !3 = !{!"TotalCount", i64 10000}
+  !4 = !{!"MaxCount", i64 10}
+  !5 = !{!"MaxInternalCount", i64 1}
+  !6 = !{!"MaxFunctionCount", i64 1000}
+  !7 = !{!"NumCounts", i64 200}
+  !8 = !{!"NumFunctions", i64 3}
+  !9 = !{!"DetailedSummary", !10}
+  !10 = !{!11, !12, !13}
+  !11 = !{i32 10000, i64 1000, i32 1}
+  !12 = !{i32 990000, i64 300, i32 10}
+  !13 = !{i32 999999, i64 5, i32 100}
+)IR";
+
+TEST(ModuleTest, setProfileSummary) {
+  SMDiagnostic Err;
+  LLVMContext Context;
+  std::unique_ptr<Module> M = parseAssemblyString(IRString, Err, Context);
+  ASSERT_TRUE(M);
+  // ProfileSummary::getFromMD returns a heap-allocated summary owned by us.
+  std::unique_ptr<ProfileSummary> PS(
+      ProfileSummary::getFromMD(M->getProfileSummary(/*IsCS*/ false)));
+  ASSERT_TRUE(PS);
+  EXPECT_FALSE(PS->isPartialProfile());
+  PS->setPartialProfile(true);
+  M->setProfileSummary(PS->getMD(Context), ProfileSummary::PSK_Sample);
+  // Re-read: the updated summary must have replaced the original one.
+  PS.reset(ProfileSummary::getFromMD(M->getProfileSummary(/*IsCS*/ false)));
+  ASSERT_TRUE(PS);
+  EXPECT_TRUE(PS->isPartialProfile());
+}
+
 } // end namespace


        


More information about the llvm-commits mailing list