[Mlir-commits] [mlir] [MLIR][LLVM] Add ProfileSummary module flag support (PR #138070)

Bruno Cardoso Lopes llvmlistbot at llvm.org
Fri May 2 16:28:45 PDT 2025


https://github.com/bcardosolopes updated https://github.com/llvm/llvm-project/pull/138070

>From 66bba757ef8e7ad43a10e80a4540ba7bb51ecdac Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Fri, 25 Apr 2025 14:45:11 -0700
Subject: [PATCH 01/13] [MLIR][LLVM] Add ProfileSummary module flag support

Unlike "CG Profile", LLVM proper does not verify the content of the metadata;
instead, it returns an empty summary when the metadata is ill-formed. To that
end, the importer here performs a significant number of checks to avoid
consuming bad content.
---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       |  48 ++++
 .../mlir/Dialect/LLVMIR/LLVMDialect.td        |   3 +
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      |  13 +
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  68 +++++
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 259 +++++++++++++++++-
 mlir/test/Dialect/LLVMIR/invalid.mlir         |   7 +
 .../test/Dialect/LLVMIR/module-roundtrip.mlir |  26 +-
 .../Target/LLVMIR/Import/import-failure.ll    | 124 +++++++++
 .../test/Target/LLVMIR/Import/module-flags.ll |  67 ++++-
 mlir/test/Target/LLVMIR/llvmir.mlir           |  27 ++
 10 files changed, 627 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 7d6d38ecad897..5eb66745db829 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1378,6 +1378,54 @@ def ModuleFlagCGProfileEntryAttr
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def ModuleFlagProfileSummaryDetailedAttr
+    : LLVM_Attr<"ModuleFlagProfileSummaryDetailed", "profile_summary_detailed"> {
+  let summary = "ProfileSummary detailed information";
+  let description = [{
+    Contains detailed information pertinent to the "ProfileSummary" attribute.
+    A `#llvm.profile_summary` may contain several of them in `detailed_summary`:
+    ```mlir
+    detailed_summary = [
+      #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+      #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+    ]
+    ```
+  }];
+  let parameters = (ins "uint32_t":$cut_off,
+                        "uint64_t":$min_count,
+                        "uint32_t":$num_counts);
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def ModuleFlagProfileSummaryAttr
+    : LLVM_Attr<"ModuleFlagProfileSummary", "profile_summary"> {
+  let summary = "ProfileSummary module flag";
+  let description = [{
+    Describes the ProfileSummary data gathered for a module. Example:
+    ```mlir
+    llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
+      #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+        max_internal_count = 86427, max_function_count = 4691,
+        num_counts = 3712, num_functions = 796,
+        is_partial_profile = 0 : i64,
+        partial_profile_ratio = 0.000000e+00 : f64,
+        detailed_summary = [
+        #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+        #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+    ]>>]
+    ```
+  }];
+  let parameters = (
+    ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count,
+        "uint64_t":$max_internal_count, "uint64_t":$max_function_count,
+        "uint64_t":$num_counts, "uint64_t":$num_functions,
+        OptionalParameter<"IntegerAttr">:$is_partial_profile,
+        OptionalParameter<"FloatAttr">:$partial_profile_ratio,
+        "ArrayAttr":$detailed_summary);
+
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // LLVM_DependentLibrariesAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
index 9f9d075a3eebf..b5ea8fc5da500 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td
@@ -92,6 +92,9 @@ def LLVM_Dialect : Dialect {
     static StringRef getModuleFlagKeyCGProfileName() {
       return "CG Profile";
     }
+    static StringRef getModuleFlagKeyProfileSummaryName() {
+      return "ProfileSummary"; // Key string used in !llvm.module.flags.
+    }
 
     /// Returns `true` if the given type is compatible with the LLVM dialect.
     static bool isCompatibleType(Type);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ffde597ac83c1..ef689e3721d91 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -390,6 +390,27 @@ ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return success();
   }
 
+  if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) {
+    auto summaryAttr = dyn_cast<ModuleFlagProfileSummaryAttr>(value);
+    if (!summaryAttr)
+      return emitError() << "'ProfileSummary' key expects a "
+                            "'#llvm.profile_summary' attribute";
+
+    StringRef fmt = summaryAttr.getFormat().getValue();
+    if (fmt != "SampleProfile" && fmt != "InstrProf" && fmt != "CSInstrProf")
+      return emitError() << "'ProfileFormat' must be 'SampleProfile', "
+                            "'InstrProf' or 'CSInstrProf'";
+
+    // The LLVM IR translation walks `detailed_summary` as a range of
+    // `#llvm.profile_summary_detailed` attributes, so reject anything else.
+    if (!llvm::all_of(summaryAttr.getDetailedSummary(), [](Attribute entry) {
+          return isa<ModuleFlagProfileSummaryDetailedAttr>(entry);
+        }))
+      return emitError() << "'detailed_summary' must only contain "
+                            "'#llvm.profile_summary_detailed' attributes";
+    return success();
+  }
+
   if (isa<IntegerAttr, StringAttr>(value))
     return success();
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 35dcde2a33d41..260d61f97fce5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -300,9 +300,81 @@ convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
     }
     return llvm::MDTuple::getDistinct(context, nodes);
   }
   return nullptr;
 }
 
+/// Converts a `#llvm.profile_summary` attribute into "ProfileSummary" module
+/// flag metadata: a tuple of key/value pairs (ProfileFormat, TotalCount,
+/// MaxCount, MaxInternalCount, MaxFunctionCount, NumCounts, NumFunctions, then
+/// IsPartialProfile and PartialProfileRatio when present) followed by the
+/// DetailedSummary pair. `key` and `moduleTranslation` are currently unused.
+static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
+    StringRef key, ModuleFlagProfileSummaryAttr summaryAttr,
+    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::LLVMContext &context = builder.getContext();
+  llvm::MDBuilder mdb(context);
+  llvm::Type *int32Ty = llvm::Type::getInt32Ty(context);
+  llvm::Type *int64Ty = llvm::Type::getInt64Ty(context);
+
+  // Builds a `!{!"<name>", i64 <val>}` pair. The parameter is not called `key`
+  // so that it does not shadow the function parameter of the same name.
+  auto getIntTuple = [&](StringRef name, uint64_t val) -> llvm::MDTuple * {
+    SmallVector<llvm::Metadata *> tupleNodes{
+        mdb.createString(name),
+        mdb.createConstant(llvm::ConstantInt::get(int64Ty, val))};
+    return llvm::MDTuple::get(context, tupleNodes);
+  };
+
+  SmallVector<llvm::Metadata *> fmtNode{
+      mdb.createString("ProfileFormat"),
+      mdb.createString(summaryAttr.getFormat().getValue())};
+
+  SmallVector<llvm::Metadata *> vals = {
+      llvm::MDTuple::get(context, fmtNode),
+      getIntTuple("TotalCount", summaryAttr.getTotalCount()),
+      getIntTuple("MaxCount", summaryAttr.getMaxCount()),
+      getIntTuple("MaxInternalCount", summaryAttr.getMaxInternalCount()),
+      getIntTuple("MaxFunctionCount", summaryAttr.getMaxFunctionCount()),
+      getIntTuple("NumCounts", summaryAttr.getNumCounts()),
+      getIntTuple("NumFunctions", summaryAttr.getNumFunctions()),
+  };
+
+  // The optional fields are only emitted when set on the attribute.
+  if (IntegerAttr isPartialProfile = summaryAttr.getIsPartialProfile())
+    vals.push_back(
+        getIntTuple("IsPartialProfile", isPartialProfile.getUInt()));
+
+  if (FloatAttr partialProfileRatio = summaryAttr.getPartialProfileRatio()) {
+    SmallVector<llvm::Metadata *> tupleNodes{
+        mdb.createString("PartialProfileRatio"),
+        mdb.createConstant(llvm::ConstantFP::get(
+            llvm::Type::getDoubleTy(context), partialProfileRatio.getValue()))};
+    vals.push_back(llvm::MDTuple::get(context, tupleNodes));
+  }
+
+  // Detailed entries use the LLVM IR layout {i32 cut_off, i64 min_count,
+  // i32 num_counts}, matching the widths of the attribute's parameters.
+  SmallVector<llvm::Metadata *> detailedEntries;
+  for (auto detailedEntry :
+       summaryAttr.getDetailedSummary()
+           .getAsRange<ModuleFlagProfileSummaryDetailedAttr>()) {
+    SmallVector<llvm::Metadata *> tupleNodes{
+        mdb.createConstant(
+            llvm::ConstantInt::get(int32Ty, detailedEntry.getCutOff())),
+        mdb.createConstant(
+            llvm::ConstantInt::get(int64Ty, detailedEntry.getMinCount())),
+        mdb.createConstant(
+            llvm::ConstantInt::get(int32Ty, detailedEntry.getNumCounts()))};
+    detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes));
+  }
+  SmallVector<llvm::Metadata *> detailedSummary{
+      mdb.createString("DetailedSummary"),
+      llvm::MDTuple::get(context, detailedEntries)};
+  vals.push_back(llvm::MDTuple::get(context, detailedSummary));
+
+  return llvm::MDTuple::get(context, vals);
+}
+
 static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
                                  LLVM::ModuleTranslation &moduleTranslation) {
   llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
@@ -323,6 +395,11 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
                                             arrayAttr, builder,
                                             moduleTranslation);
             })
+            .Case<ModuleFlagProfileSummaryAttr>([&](auto summaryAttr) {
+              return convertModuleFlagProfileSummaryAttr(
+                  flagAttr.getKey().getValue(), summaryAttr, builder,
+                  moduleTranslation);
+            })
             .Default([](auto) { return nullptr; });
 
     assert(valueMetadata && "expected valid metadata");
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 0b77a3d23d392..fff2ae4a65f2d 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -554,13 +554,276 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
   return ArrayAttr::get(mlirModule->getContext(), cgProfile);
 }
 
+/// Converts the "ProfileSummary" module flag metadata into a
+/// `#llvm.profile_summary` attribute. `mdTuple` is expected to hold 2-element
+/// key/value tuples in a fixed order: ProfileFormat, TotalCount, MaxCount,
+/// MaxInternalCount, MaxFunctionCount, NumCounts, NumFunctions, optionally
+/// IsPartialProfile and PartialProfileRatio, and finally DetailedSummary.
+/// Returns nullptr (usually after emitting a warning) if it is ill-formed.
+static Attribute
+convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
+                                     const llvm::Module *llvmModule,
+                                     llvm::MDTuple *mdTuple) {
+  unsigned profileNumEntries = mdTuple->getNumOperands();
+  if (profileNumEntries < 8) {
+    emitWarning(mlirModule.getLoc())
+        << "expected at 8 entries in 'ProfileSummary': "
+        << diagMD(mdTuple, llvmModule);
+    return nullptr;
+  }
+
+  // Index of the next entry of `mdTuple` to be consumed.
+  unsigned summaryIdx = 0;
+
+  // Returns `md` as a key/value tuple if it has exactly two non-null
+  // operands; otherwise warns and returns nullptr, so that callers can read
+  // operands 0 and 1 without further checks.
+  auto getMDTuple = [&](const llvm::MDOperand &md) -> llvm::MDTuple * {
+    auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
+    if (!tupleEntry || tupleEntry->getNumOperands() != 2 ||
+        !tupleEntry->getOperand(0).get() || !tupleEntry->getOperand(1).get()) {
+      emitWarning(mlirModule.getLoc())
+          << "expected 2-element tuple metadata: " << diagMD(md, llvmModule);
+      return nullptr;
+    }
+    return tupleEntry;
+  };
+
+  auto getFormat = [&](const llvm::MDOperand &formatMD) -> StringAttr {
+    auto *tupleEntry = getMDTuple(formatMD);
+    if (!tupleEntry)
+      return nullptr;
+
+    auto *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+    if (!keyMD || keyMD->getString() != "ProfileFormat") {
+      emitWarning(mlirModule.getLoc())
+          << "expected 'ProfileFormat' key: "
+          << diagMD(tupleEntry->getOperand(0), llvmModule);
+      return nullptr;
+    }
+
+    // A non-string value is reported like an unknown format rather than
+    // dereferenced as an MDString.
+    auto *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(1));
+    StringRef format = valMD ? valMD->getString() : StringRef();
+    if (format != "SampleProfile" && format != "InstrProf" &&
+        format != "CSInstrProf") {
+      emitWarning(mlirModule.getLoc())
+          << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
+             "but found: "
+          << diagMD(tupleEntry->getOperand(1), llvmModule);
+      return nullptr;
+    }
+
+    return StringAttr::get(mlirModule->getContext(), format);
+  };
+
+  // Returns the constant value of the key/value tuple `md` when its key is
+  // `matchKey`. A key mismatch is only diagnosed when `optional` is false.
+  auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
+                           bool optional =
+                               false) -> llvm::ConstantAsMetadata * {
+    auto *tupleEntry = getMDTuple(md);
+    if (!tupleEntry)
+      return nullptr;
+    auto *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+    if (!keyMD || keyMD->getString() != matchKey) {
+      if (!optional)
+        emitWarning(mlirModule.getLoc())
+            << "expected '" << matchKey << "' key, but found: "
+            << diagMD(tupleEntry->getOperand(0), llvmModule);
+      return nullptr;
+    }
+
+    return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
+  };
+
+  auto checkOptionalPosition = [&](const llvm::MDOperand &md,
+                                   StringRef matchKey) -> LogicalResult {
+    // Make sure we won't step over the bound of the array of summary entries.
+    // Since (non-optional) DetailedSummary always comes last, the next entry in
+    // the tuple operand array must exist.
+    if (summaryIdx + 1 >= profileNumEntries) {
+      emitWarning(mlirModule.getLoc())
+          << "the last summary entry is '" << matchKey
+          << "', expected 'DetailedSummary': " << diagMD(md, llvmModule);
+      return failure();
+    }
+
+    return success();
+  };
+
+  auto getInt64Value = [&](const llvm::MDOperand &md, StringRef matchKey,
+                           uint64_t &val) {
+    auto *valMD = getConstantMD(md, matchKey);
+    if (!valMD)
+      return false;
+
+    if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue())) {
+      val = cstInt->getZExtValue();
+      return true;
+    }
+
+    emitWarning(mlirModule.getLoc())
+        << "expected integer metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+    return false;
+  };
+
+  auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
+                            IntegerAttr &attr) -> LogicalResult {
+    if (!getConstantMD(md, matchKey, /*optional=*/true))
+      return success();
+    if (checkOptionalPosition(md, matchKey).failed())
+      return failure();
+    uint64_t val = 0;
+    if (!getInt64Value(md, matchKey, val))
+      return failure();
+    attr =
+        IntegerAttr::get(IntegerType::get(mlirModule->getContext(), 64), val);
+    return success();
+  };
+
+  auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
+                               FloatAttr &attr) -> LogicalResult {
+    auto *valMD = getConstantMD(md, matchKey, /*optional=*/true);
+    if (!valMD)
+      return success();
+    if (auto *cstFP = dyn_cast<llvm::ConstantFP>(valMD->getValue())) {
+      if (checkOptionalPosition(md, matchKey).failed())
+        return failure();
+      attr = FloatAttr::get(Float64Type::get(mlirModule.getContext()),
+                            cstFP->getValueAPF());
+      return success();
+    }
+    emitWarning(mlirModule.getLoc())
+        << "expected double metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+    return failure();
+  };
+
+  auto getSummary = [&](const llvm::MDOperand &summaryMD) -> ArrayAttr {
+    auto *tupleEntry = getMDTuple(summaryMD);
+    if (!tupleEntry)
+      return nullptr;
+
+    auto *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+    if (!keyMD || keyMD->getString() != "DetailedSummary") {
+      emitWarning(mlirModule.getLoc())
+          << "expected 'DetailedSummary' key: "
+          << diagMD(tupleEntry->getOperand(0), llvmModule);
+      return nullptr;
+    }
+
+    auto *entriesMD = dyn_cast<llvm::MDTuple>(tupleEntry->getOperand(1));
+    if (!entriesMD) {
+      emitWarning(mlirModule.getLoc())
+          << "expected tuple value for 'DetailedSummary' key: "
+          << diagMD(tupleEntry->getOperand(1), llvmModule);
+      return nullptr;
+    }
+
+    SmallVector<Attribute> detailedSummary;
+    for (auto &&entry : entriesMD->operands()) {
+      auto *entryMD = dyn_cast_or_null<llvm::MDTuple>(entry);
+      if (!entryMD || entryMD->getNumOperands() != 3) {
+        emitWarning(mlirModule.getLoc())
+            << "'DetailedSummary' entry expects 3 operands: "
+            << diagMD(entry, llvmModule);
+        return nullptr;
+      }
+
+      // Every field must be a ConstantInt; a plain `cast` would assert on,
+      // e.g., a floating-point constant.
+      auto getIntOperand = [&](unsigned idx) -> llvm::ConstantInt * {
+        auto *cst = dyn_cast_or_null<llvm::ConstantAsMetadata>(
+            entryMD->getOperand(idx));
+        return cst ? dyn_cast<llvm::ConstantInt>(cst->getValue()) : nullptr;
+      };
+      llvm::ConstantInt *cutOffCst = getIntOperand(0);
+      llvm::ConstantInt *minCountCst = getIntOperand(1);
+      llvm::ConstantInt *numCountsCst = getIntOperand(2);
+      if (!cutOffCst || !minCountCst || !numCountsCst) {
+        emitWarning(mlirModule.getLoc())
+            << "expected only integer entries in 'DetailedSummary': "
+            << diagMD(entry, llvmModule);
+        return nullptr;
+      }
+
+      detailedSummary.push_back(ModuleFlagProfileSummaryDetailedAttr::get(
+          mlirModule->getContext(), cutOffCst->getZExtValue(),
+          minCountCst->getZExtValue(), numCountsCst->getZExtValue()));
+    }
+    return ArrayAttr::get(mlirModule->getContext(), detailedSummary);
+  };
+
+  // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
+  // a fixed order: format, total count, etc.
+  StringAttr format = getFormat(mdTuple->getOperand(summaryIdx++));
+  if (!format)
+    return nullptr;
+
+  uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
+           maxFunctionCount = 0, numCounts = 0, numFunctions = 0;
+  if (!getInt64Value(mdTuple->getOperand(summaryIdx++), "TotalCount",
+                     totalCount))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summaryIdx++), "MaxCount", maxCount))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summaryIdx++), "MaxInternalCount",
+                     maxInternalCount))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summaryIdx++), "MaxFunctionCount",
+                     maxFunctionCount))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summaryIdx++), "NumCounts",
+                     numCounts))
+    return nullptr;
+  if (!getInt64Value(mdTuple->getOperand(summaryIdx++), "NumFunctions",
+                     numFunctions))
+    return nullptr;
+
+  // Handle optional keys.
+  IntegerAttr isPartialProfile;
+  if (getOptIntValue(mdTuple->getOperand(summaryIdx), "IsPartialProfile",
+                     isPartialProfile)
+          .failed())
+    return nullptr;
+  if (isPartialProfile)
+    summaryIdx++;
+
+  FloatAttr partialProfileRatio;
+  if (getOptDoubleValue(mdTuple->getOperand(summaryIdx), "PartialProfileRatio",
+                        partialProfileRatio)
+          .failed())
+    return nullptr;
+  if (partialProfileRatio)
+    summaryIdx++;
+
+  // Handle detailed summary.
+  ArrayAttr detailedSummary = getSummary(mdTuple->getOperand(summaryIdx));
+  if (!detailedSummary)
+    return nullptr;
+
+  // Build the final profile summary attribute; absent optional fields are
+  // passed as null attributes.
+  return ModuleFlagProfileSummaryAttr::get(
+      mlirModule->getContext(), format, totalCount, maxCount, maxInternalCount,
+      maxFunctionCount, numCounts, numFunctions, isPartialProfile,
+      partialProfileRatio, detailedSummary);
+}
+
 /// Invoke specific handlers for each known module flag value, returns nullptr
 /// if the key is unknown or unimplemented.
-static Attribute convertModuleFlagValueFromMDTuple(ModuleOp mlirModule,
-                                                   StringRef key,
-                                                   llvm::MDTuple *mdTuple) {
+static Attribute
+convertModuleFlagValueFromMDTuple(ModuleOp mlirModule,
+                                  const llvm::Module *llvmModule, StringRef key,
+                                  llvm::MDTuple *mdTuple) {
   if (key == LLVMDialect::getModuleFlagKeyCGProfileName())
     return convertCGProfileModuleFlagValue(mlirModule, mdTuple);
+  if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName())
+    return convertProfileSummaryModuleFlagValue(mlirModule, llvmModule,
+                                                mdTuple);
   return nullptr;
 }
 
@@ -576,8 +839,8 @@ LogicalResult ModuleImport::convertModuleFlagsMetadata() {
     } else if (auto *mdString = dyn_cast<llvm::MDString>(val)) {
       valAttr = builder.getStringAttr(mdString->getString());
     } else if (auto *mdTuple = dyn_cast<llvm::MDTuple>(val)) {
-      valAttr = convertModuleFlagValueFromMDTuple(mlirModule, key->getString(),
-                                                  mdTuple);
+      valAttr = convertModuleFlagValueFromMDTuple(mlirModule, llvmModule.get(),
+                                                  key->getString(), mdTuple);
     }
 
     if (!valAttr) {
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 5dea94026b248..84c0d40c8b346 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1800,6 +1800,13 @@ module {
 
 // -----
 
+module {
+  // expected-error@below {{'ProfileSummary' key expects a '#llvm.profile_summary' attribute}}
+  llvm.module_flags [#llvm.mlir.module_flag<append, "ProfileSummary", 3 : i64>]
+}
+
+// -----
+
 llvm.func @t0() -> !llvm.ptr {
   %0 = llvm.blockaddress <function = @t0, tag = <id = 1>> : !llvm.ptr
   llvm.blocktag <id = 1>
diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
index 025d9b2287c42..62a16de6b6d97 100644
--- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
@@ -11,7 +11,17 @@ module {
                        #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
                        #llvm.cgprofile_entry<from = @from, count = 222>,
                        #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
-                    ]>]
+                    ]>,
+                    #llvm.mlir.module_flag<error, "ProfileSummary",
+                       #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+                         max_internal_count = 86427, max_function_count = 4691,
+                         num_counts = 3712, num_functions = 796,
+                         is_partial_profile = 0 : i64,
+                         partial_profile_ratio = 0.000000e+00 : f64,
+                         detailed_summary = [
+                           #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+                           #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+                    ]>>]
 }
 
 // CHECK: llvm.module_flags [
@@ -25,4 +35,18 @@ module {
 // CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
 // CHECK-SAME: #llvm.cgprofile_entry<from = @from, count = 222>,
 // CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
-// CHECK-SAME: ]>]
+// CHECK-SAME: ]>,
+// CHECK-SAME: #llvm.mlir.module_flag<error, "ProfileSummary",
+// CHECK-SAME:    #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+// CHECK-SAME:      max_internal_count = 86427, max_function_count = 4691,
+// CHECK-SAME:      num_counts = 3712, num_functions = 796,
+// CHECK-SAME:      is_partial_profile = 0 : i64,
+// CHECK-SAME:      partial_profile_ratio = 0.000000e+00 : f64,
+// CHECK-SAME:      detailed_summary = [
+// CHECK-SAME:        #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+// CHECK-SAME:        #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+// CHECK-SAME: ]>>]
+
+llvm.module_flags []
+
+// CHECK: llvm.module_flags []
diff --git a/mlir/test/Target/LLVMIR/Import/import-failure.ll b/mlir/test/Target/LLVMIR/Import/import-failure.ll
index 782925a0a938e..7571158a57d14 100644
--- a/mlir/test/Target/LLVMIR/Import/import-failure.ll
+++ b/mlir/test/Target/LLVMIR/Import/import-failure.ll
@@ -348,3 +348,127 @@ define void @fn() {
 bb1:
   ret void
 }
+
+; // -----
+
+!10 = !{ i32 1, !"foo", i32 1 }
+!11 = !{ i32 4, !"bar", i32 37 }
+!12 = !{ i32 2, !"qux", i32 42 }
+; CHECK: unsupported module flag value for key 'qux' : !4 = !{!"foo", i32 1}
+!13 = !{ i32 3, !"qux", !{ !"foo", i32 1 }}  ; tuple value for a key with no tuple handler
+!llvm.module.flags = !{ !10, !11, !12, !13 }
+
+; // -----
+
+!llvm.module.flags = !{!41873}
+
+!41873 = !{i32 1, !"ProfileSummary", !41874}
+!41874 = !{!41875, !41876, !41877, !41878, !41880, !41881, !41882, !41883, !41884}  ; no MaxFunctionCount entry
+!41875 = !{!"ProfileFormat", !"InstrProf"}
+!41876 = !{!"TotalCount", i64 263646}
+!41877 = !{!"MaxCount", i64 86427}
+!41878 = !{!"MaxInternalCount", i64 86427}
+; CHECK: expected 'MaxFunctionCount' key, but found: !"NumCounts"
+!41880 = !{!"NumCounts", i64 3712}
+!41881 = !{!"NumFunctions", i64 796}
+!41882 = !{!"IsPartialProfile", i64 0}
+!41883 = !{!"PartialProfileRatio", double 0.000000e+00}
+!41884 = !{!"DetailedSummary", !41885}
+!41885 = !{!41886, !41887}
+!41886 = !{i32 10000, i64 86427, i32 1}
+!41887 = !{i32 100000, i64 86427, i32 1}
+
+; // -----
+
+!llvm.module.flags = !{!51873}
+
+!51873 = !{i32 1, !"ProfileSummary", !51874}
+!51874 = !{!51875, !51876, !51877, !51878, !51879, !51880, !51881, !51882, !51883, !51884}
+!51875 = !{!"ProfileFormat", !"InstrProf"}
+!51876 = !{!"TotalCount", i64 263646}
+!51877 = !{!"MaxCount", i64 86427}
+!51878 = !{!"MaxInternalCount", i64 86427}
+!51879 = !{!"MaxFunctionCount", i64 4691}
+!51880 = !{!"NumCounts", i64 3712}
+; CHECK: expected integer metadata value for key 'NumFunctions'
+!51881 = !{!"NumFunctions", double 0.000000e+00}  ; double instead of an integer
+!51882 = !{!"IsPartialProfile", i64 0}
+!51883 = !{!"PartialProfileRatio", double 0.000000e+00}
+!51884 = !{!"DetailedSummary", !51885}
+!51885 = !{!51886, !51887}
+!51886 = !{i32 10000, i64 86427, i32 1}
+!51887 = !{i32 100000, i64 86427, i32 1}
+
+; // -----
+
+!llvm.module.flags = !{!61873}
+
+!61873 = !{i32 1, !"ProfileSummary", !61874}
+!61874 = !{!61875, !61876, !61877, !61878, !61879, !61880, !61881, !61882, !61883, !61884}
+; CHECK: expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, but found: !"MyThingyFmt"
+!61875 = !{!"ProfileFormat", !"MyThingyFmt"}  ; unknown profile format
+!61876 = !{!"TotalCount", i64 263646}
+!61877 = !{!"MaxCount", i64 86427}
+!61878 = !{!"MaxInternalCount", i64 86427}
+!61879 = !{!"MaxFunctionCount", i64 4691}
+!61880 = !{!"NumCounts", i64 3712}
+!61881 = !{!"NumFunctions", i64 796}
+!61882 = !{!"IsPartialProfile", i64 0}
+!61883 = !{!"PartialProfileRatio", double 0.000000e+00}
+!61884 = !{!"DetailedSummary", !61885}
+!61885 = !{!61886, !61887}
+!61886 = !{i32 10000, i64 86427, i32 1}
+!61887 = !{i32 100000, i64 86427, i32 1}
+
+; // -----
+
+!llvm.module.flags = !{!71873}
+
+!71873 = !{i32 1, !"ProfileSummary", !71874}
+!71874 = !{!71875, !71876, !71877, !71878, !71879, !71880, !71881, !71882, !71883}  ; no DetailedSummary entry
+!71875 = !{!"ProfileFormat", !"InstrProf"}
+!71876 = !{!"TotalCount", i64 263646}
+!71877 = !{!"MaxCount", i64 86427}
+!71878 = !{!"MaxInternalCount", i64 86427}
+!71879 = !{!"MaxFunctionCount", i64 4691}
+!71880 = !{!"NumCounts", i64 3712}
+!71881 = !{!"NumFunctions", i64 796}
+!71882 = !{!"IsPartialProfile", i64 0}
+; CHECK: the last summary entry is 'PartialProfileRatio', expected 'DetailedSummary'
+!71883 = !{!"PartialProfileRatio", double 0.000000e+00}
+
+; // -----
+
+!llvm.module.flags = !{!81873}
+
+!81873 = !{i32 1, !"ProfileSummary", !81874}
+; CHECK: expected at 8 entries in 'ProfileSummary'
+!81874 = !{!81875, !81876, !81877, !81878, !81879, !81880, !81881}  ; only 7 entries
+!81875 = !{!"ProfileFormat", !"InstrProf"}
+!81876 = !{!"TotalCount", i64 263646}
+!81877 = !{!"MaxCount", i64 86427}
+!81878 = !{!"MaxInternalCount", i64 86427}
+!81879 = !{!"MaxFunctionCount", i64 4691}
+!81880 = !{!"NumCounts", i64 3812}
+!81881 = !{!"NumFunctions", i64 796}
+
+; // -----
+
+!llvm.module.flags = !{!91873}
+
+!91873 = !{i32 1, !"ProfileSummary", !91874}
+!91874 = !{!91875, !91876, !91877, !91878, !91879, !91880, !91881, !91882, !91883, !91884}
+!91875 = !{!"ProfileFormat", !"InstrProf"}
+; CHECK: expected 2-element tuple metadata
+!91876 = !{!"TotalCount", i64 263646, i64 263646}  ; three operands instead of two
+!91877 = !{!"MaxCount", i64 86427}
+!91878 = !{!"MaxInternalCount", i64 86427}
+!91879 = !{!"MaxFunctionCount", i64 4691}
+!91880 = !{!"NumCounts", i64 3712}
+!91881 = !{!"NumFunctions", i64 796}
+!91882 = !{!"IsPartialProfile", i64 0}
+!91883 = !{!"PartialProfileRatio", double 0.000000e+00}
+!91884 = !{!"DetailedSummary", !91885}
+!91885 = !{!91886, !91887}
+!91886 = !{i32 10000, i64 86427, i32 1}
+!91887 = !{i32 100000, i64 86427, i32 1}
diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll
index 09e708de0cc93..49895c4f26241 100644
--- a/mlir/test/Target/LLVMIR/Import/module-flags.ll
+++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll
@@ -18,14 +18,6 @@
 ; CHECK-SAME: #llvm.mlir.module_flag<max, "frame-pointer", 1 : i32>,
 ; CHECK-SAME: #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">]
 
-; // -----
-; expected-warning@-2 {{unsupported module flag value for key 'qux' : !4 = !{!"foo", i32 1}}}
-!10 = !{ i32 1, !"foo", i32 1 }
-!11 = !{ i32 4, !"bar", i32 37 }
-!12 = !{ i32 2, !"qux", i32 42 }
-!13 = !{ i32 3, !"qux", !{ !"foo", i32 1 }}
-!llvm.module.flags = !{ !10, !11, !12, !13 }
-
 ; // -----
 
 declare void @from(i32)
@@ -44,3 +36,62 @@ declare void @to()
 ; CHECK-SAME: #llvm.cgprofile_entry<from = @from, count = 222>,
 ; CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
 ; CHECK-SAME: ]>]
+
+; // -----
+
+!llvm.module.flags = !{!31873}
+
+!31873 = !{i32 1, !"ProfileSummary", !31874}
+!31874 = !{!31875, !31876, !31877, !31878, !31879, !31880, !31881, !31882, !31883, !31884}  ; all optional entries present
+!31875 = !{!"ProfileFormat", !"InstrProf"}
+!31876 = !{!"TotalCount", i64 263646}
+!31877 = !{!"MaxCount", i64 86427}
+!31878 = !{!"MaxInternalCount", i64 86427}
+!31879 = !{!"MaxFunctionCount", i64 4691}
+!31880 = !{!"NumCounts", i64 3712}
+!31881 = !{!"NumFunctions", i64 796}
+!31882 = !{!"IsPartialProfile", i64 0}
+!31883 = !{!"PartialProfileRatio", double 0.000000e+00}
+!31884 = !{!"DetailedSummary", !31885}
+!31885 = !{!31886, !31887}
+!31886 = !{i32 10000, i64 86427, i32 1}
+!31887 = !{i32 100000, i64 86427, i32 1}
+
+; CHECK: llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
+; CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646,
+; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
+; CHECK-SAME: num_counts = 3712, num_functions = 796, is_partial_profile = 0 : i64,
+; CHECK-SAME: partial_profile_ratio = 0.000000e+00 : f64,
+; CHECK-SAME: detailed_summary = [
+; CHECK-SAME: #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+; CHECK-SAME: #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+; CHECK-SAME: ]>>]
+
+; // -----
+
+; Test that the optional IsPartialProfile and PartialProfileRatio entries can be omitted.
+
+!llvm.module.flags = !{!41873}
+
+!41873 = !{i32 1, !"ProfileSummary", !41874}
+!41874 = !{!41875, !41876, !41877, !41878, !41879, !41880, !41881, !41884}
+!41875 = !{!"ProfileFormat", !"InstrProf"}
+!41876 = !{!"TotalCount", i64 263646}
+!41877 = !{!"MaxCount", i64 86427}
+!41878 = !{!"MaxInternalCount", i64 86427}
+!41879 = !{!"MaxFunctionCount", i64 4691}
+!41880 = !{!"NumCounts", i64 3712}
+!41881 = !{!"NumFunctions", i64 796}
+!41884 = !{!"DetailedSummary", !41885}
+!41885 = !{!41886, !41887}
+!41886 = !{i32 10000, i64 86427, i32 1}
+!41887 = !{i32 100000, i64 86427, i32 1}
+
+; CHECK: llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
+; CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646,
+; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
+; CHECK-SAME: num_counts = 3712, num_functions = 796,
+; CHECK-SAME: detailed_summary = [
+; CHECK-SAME: #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+; CHECK-SAME: #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+; CHECK-SAME: ]>>]
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 9852c4051f0d0..dc347430eb0b7 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2883,6 +2883,33 @@ llvm.func @to()
 
 // -----
 
+llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
+                       #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+                         max_internal_count = 86427, max_function_count = 4691,
+                         num_counts = 3712, num_functions = 796,
+                         is_partial_profile = 0 : i64,
+                         partial_profile_ratio = 0.000000e+00 : f64,
+                         detailed_summary = [
+                           #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+                           #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+                  ]>>]
+
+// CHECK: !llvm.module.flags = !{!0, !15}
+
+// CHECK: !0 = !{i32 1, !"ProfileSummary", !1}
+// CHECK: !1 = !{!2, !3, !4, !5, !6, !7, !8, !9, !10, !11}
+// CHECK: !2 = !{!"ProfileFormat", !"InstrProf"}
+// CHECK: !3 = !{!"TotalCount", i64 263646}
+// CHECK: !4 = !{!"MaxCount", i64 86427}
+// CHECK: !5 = !{!"MaxInternalCount", i64 86427}
+// CHECK: !6 = !{!"MaxFunctionCount", i64 4691}
+// CHECK: !7 = !{!"NumCounts", i64 3712}
+// CHECK: !8 = !{!"NumFunctions", i64 796}
+// CHECK: !9 = !{!"IsPartialProfile", i64 0}
+// CHECK: !10 = !{!"PartialProfileRatio", double 0.000000e+00}
+
+// -----
+
 module attributes {llvm.dependent_libraries = ["foo", "bar"]} {}
 
 // CHECK: !llvm.dependent-libraries =  !{![[#LIBFOO:]], ![[#LIBBAR:]]}

>From b661bd68a0d4d360486d44f0b8b231fa0b6c726f Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Thu, 1 May 2025 14:37:48 -0700
Subject: [PATCH 02/13] add verifier test for format

---
 mlir/test/Dialect/LLVMIR/invalid.mlir | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 84c0d40c8b346..8f3fe03a0303e 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1807,6 +1807,20 @@ module {
 
 // -----
 
+// expected-error@below {{'ProfileFormat' must be 'SampleProfile', 'InstrProf' or 'CSInstrProf'}}
+llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
+     #llvm.profile_summary<format = "YoloFmt", total_count = 263646, max_count = 86427,
+       max_internal_count = 86427, max_function_count = 4691,
+       num_counts = 3712, num_functions = 796,
+       is_partial_profile = 0 : i64,
+       partial_profile_ratio = 0.000000e+00 : f64,
+       detailed_summary = [
+         #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
+         #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+]>>]
+
+// -----
+
 llvm.func @t0() -> !llvm.ptr {
   %0 = llvm.blockaddress <function = @t0, tag = <id = 1>> : !llvm.ptr
   llvm.blocktag <id = 1>

>From abfcb67658004bcdd8775d08d56756d993c0decd Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Thu, 1 May 2025 14:51:33 -0700
Subject: [PATCH 03/13] Use ArrayRefParameter

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 16 +++++++-------
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  4 +---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 21 ++++++++++---------
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  8 +++----
 .../test/Dialect/LLVMIR/module-roundtrip.mlir | 16 +++++++-------
 .../test/Target/LLVMIR/Import/module-flags.ll | 16 +++++++-------
 mlir/test/Target/LLVMIR/llvmir.mlir           |  8 +++----
 7 files changed, 44 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 5eb66745db829..ef05e884b62fe 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1386,9 +1386,9 @@ def ModuleFlagProfileSummaryDetailedAttr
     A `#llvm.profile_summary` may contain several of it.
     ```mlir
     llvm.module_flags [ ...
-        detailed_summary = [
-        #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
-        #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
+        detailed_summary =
+        <cut_off = 10000, min_count = 86427, num_counts = 1>,
+        <cut_off = 100000, min_count = 86427, num_counts = 1>
     ```
   }];
   let parameters = (ins "uint32_t":$cut_off,
@@ -1409,10 +1409,10 @@ def ModuleFlagProfileSummaryAttr
         num_counts = 3712, num_functions = 796,
         is_partial_profile = 0 : i64,
         partial_profile_ratio = 0.000000e+00 : f64,
-        detailed_summary = [
-        #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
-        #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
-    ]>>]
+        detailed_summary =
+          <cut_off = 10000, min_count = 86427, num_counts = 1>,
+          <cut_off = 100000, min_count = 86427, num_counts = 1>
+    >>]
     ```
   }];
   let parameters = (
@@ -1421,7 +1421,7 @@ def ModuleFlagProfileSummaryAttr
         "uint64_t":$num_counts, "uint64_t":$num_functions,
         OptionalParameter<"IntegerAttr">:$is_partial_profile,
         OptionalParameter<"FloatAttr">:$partial_profile_ratio,
-        "ArrayAttr":$detailed_summary);
+        ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary);
 
   let assemblyFormat = "`<` struct(params) `>`";
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 260d61f97fce5..6a523debf1b3a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -346,9 +346,7 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
   }
 
   SmallVector<llvm::Metadata *> detailedEntries;
-  for (auto detailedEntry :
-       summaryAttr.getDetailedSummary()
-           .getAsRange<ModuleFlagProfileSummaryDetailedAttr>()) {
+  for (auto detailedEntry : summaryAttr.getDetailedSummary()) {
     SmallVector<llvm::Metadata *> tupleNodes{
         mdb.createConstant(llvm::ConstantInt::get(
             llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())),
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index fff2ae4a65f2d..72ffe2d09e4f8 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -688,17 +688,19 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     return failure();
   };
 
-  auto getSummary = [&](const llvm::MDOperand &summaryMD) -> ArrayAttr {
+  auto getSummary = [&](const llvm::MDOperand &summaryMD,
+                        SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr>
+                            &detailedSummary) {
     auto *tupleEntry = getMDTuple(summaryMD);
     if (!tupleEntry)
-      return nullptr;
+      return false;
 
     llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
     if (!keyMD || keyMD->getString() != "DetailedSummary") {
       emitWarning(mlirModule.getLoc())
           << "expected 'DetailedSummary' key: "
           << diagMD(tupleEntry->getOperand(0), llvmModule);
-      return nullptr;
+      return false;
     }
 
     llvm::MDTuple *entriesMD =
@@ -707,17 +709,16 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
       emitWarning(mlirModule.getLoc())
           << "expected tuple value for 'DetailedSummary' key: "
           << diagMD(tupleEntry->getOperand(1), llvmModule);
-      return nullptr;
+      return false;
     }
 
-    SmallVector<Attribute> detailedSummary;
     for (auto &&entry : entriesMD->operands()) {
       llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
       if (!entryMD || entryMD->getNumOperands() != 3) {
         emitWarning(mlirModule.getLoc())
             << "'DetailedSummary' entry expects 3 operands: "
             << diagMD(entry, llvmModule);
-        return nullptr;
+        return false;
       }
       llvm::ConstantAsMetadata *op0 =
           dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(0));
@@ -730,7 +731,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
         emitWarning(mlirModule.getLoc())
             << "expected only integer entries in 'DetailedSummary': "
             << diagMD(entry, llvmModule);
-        return nullptr;
+        return false;
       }
 
       auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
@@ -740,7 +741,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
           cast<llvm::ConstantInt>(op2->getValue())->getZExtValue());
       detailedSummary.push_back(detaildSummaryEntry);
     }
-    return ArrayAttr::get(mlirModule->getContext(), detailedSummary);
+    return true;
   };
 
   // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
@@ -787,8 +788,8 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     summayIdx++;
 
   // Handle detailed summary.
-  ArrayAttr detailedSummary = getSummary(mdTuple->getOperand(summayIdx));
-  if (!detailedSummary)
+  SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailedSummary;
+  if (!getSummary(mdTuple->getOperand(summayIdx), detailedSummary))
     return nullptr;
 
   // Build the final profile summary attribute.
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 8f3fe03a0303e..4f35358c9486a 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1814,10 +1814,10 @@ llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
        num_counts = 3712, num_functions = 796,
        is_partial_profile = 0 : i64,
        partial_profile_ratio = 0.000000e+00 : f64,
-       detailed_summary = [
-         #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
-         #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
-]>>]
+       detailed_summary =
+         <cut_off = 10000, min_count = 86427, num_counts = 1>,
+         <cut_off = 100000, min_count = 86427, num_counts = 1>
+>>]
 
 // -----
 
diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
index 62a16de6b6d97..bd6162f15527c 100644
--- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
@@ -18,10 +18,10 @@ module {
                          num_counts = 3712, num_functions = 796,
                          is_partial_profile = 0 : i64,
                          partial_profile_ratio = 0.000000e+00 : f64,
-                         detailed_summary = [
-                           #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
-                           #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
-                    ]>>]
+                         detailed_summary =
+                           <cut_off = 10000, min_count = 86427, num_counts = 1>,
+                           <cut_off = 100000, min_count = 86427, num_counts = 1>
+                    >>]
 }
 
 // CHECK: llvm.module_flags [
@@ -42,9 +42,9 @@ module {
 // CHECK-SAME:      num_counts = 3712, num_functions = 796,
 // CHECK-SAME:      is_partial_profile = 0 : i64,
 // CHECK-SAME:      partial_profile_ratio = 0.000000e+00 : f64,
-// CHECK-SAME:      detailed_summary = [
-// CHECK-SAME:        #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
-// CHECK-SAME:        #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
-// CHECK-SAME: ]>>]
+// CHECK-SAME:      detailed_summary =
+// CHECK-SAME:        <cut_off = 10000, min_count = 86427, num_counts = 1>,
+// CHECK-SAME:        <cut_off = 100000, min_count = 86427, num_counts = 1>
+// CHECK-SAME: >>]
 
 llvm.module_flags []
diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll
index 49895c4f26241..11df41a630d05 100644
--- a/mlir/test/Target/LLVMIR/Import/module-flags.ll
+++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll
@@ -62,10 +62,10 @@ declare void @to()
 ; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
 ; CHECK-SAME: num_counts = 3712, num_functions = 796, is_partial_profile = 0 : i64,
 ; CHECK-SAME: partial_profile_ratio = 0.000000e+00 : f64,
-; CHECK-SAME: detailed_summary = [
-; CHECK-SAME: #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
-; CHECK-SAME: #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
-; CHECK-SAME: ]>>]
+; CHECK-SAME: detailed_summary =
+; CHECK-SAME: <cut_off = 10000, min_count = 86427, num_counts = 1>,
+; CHECK-SAME: <cut_off = 100000, min_count = 86427, num_counts = 1>
+; CHECK-SAME: >>]
 
 ; // -----
 
@@ -91,7 +91,7 @@ declare void @to()
 ; CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646,
 ; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
 ; CHECK-SAME: num_counts = 3712, num_functions = 796,
-; CHECK-SAME: detailed_summary = [
-; CHECK-SAME: #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
-; CHECK-SAME: #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
-; CHECK-SAME: ]>>]
+; CHECK-SAME: detailed_summary =
+; CHECK-SAME: <cut_off = 10000, min_count = 86427, num_counts = 1>,
+; CHECK-SAME: <cut_off = 100000, min_count = 86427, num_counts = 1>
+; CHECK-SAME: >>]
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index dc347430eb0b7..a8da126e698ec 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2889,10 +2889,10 @@ llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
                          num_counts = 3712, num_functions = 796,
                          is_partial_profile = 0 : i64,
                          partial_profile_ratio = 0.000000e+00 : f64,
-                         detailed_summary = [
-                           #llvm.profile_summary_detailed<cut_off = 10000, min_count = 86427, num_counts = 1>,
-                           #llvm.profile_summary_detailed<cut_off = 100000, min_count = 86427, num_counts = 1>
-                  ]>>]
+                         detailed_summary =
+                           <cut_off = 10000, min_count = 86427, num_counts = 1>,
+                           <cut_off = 100000, min_count = 86427, num_counts = 1>
+                  >>]
 
 // CHECK: !llvm.module.flags = !{!0, !15}
 

>From b4acc60168c16ba0592e624338f4b51059d82d9a Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Thu, 1 May 2025 15:07:35 -0700
Subject: [PATCH 04/13] Use std::optional when possible for attr params

---
 mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td |  4 ++--
 .../Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp   |  4 ++--
 mlir/lib/Target/LLVMIR/ModuleImport.cpp          | 16 +++++++---------
 mlir/test/Dialect/LLVMIR/invalid.mlir            |  2 +-
 mlir/test/Dialect/LLVMIR/module-roundtrip.mlir   |  4 ++--
 mlir/test/Target/LLVMIR/Import/module-flags.ll   |  2 +-
 mlir/test/Target/LLVMIR/llvmir.mlir              |  2 +-
 7 files changed, 16 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index ef05e884b62fe..5a037a767a75f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1407,7 +1407,7 @@ def ModuleFlagProfileSummaryAttr
       #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
         max_internal_count = 86427, max_function_count = 4691,
         num_counts = 3712, num_functions = 796,
-        is_partial_profile = 0 : i64,
+        is_partial_profile = 0,
         partial_profile_ratio = 0.000000e+00 : f64,
         detailed_summary =
           <cut_off = 10000, min_count = 86427, num_counts = 1>,
@@ -1419,7 +1419,7 @@ def ModuleFlagProfileSummaryAttr
     ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count,
         "uint64_t":$max_internal_count, "uint64_t":$max_function_count,
         "uint64_t":$num_counts, "uint64_t":$num_functions,
-        OptionalParameter<"IntegerAttr">:$is_partial_profile,
+        OptionalParameter<"std::optional<uint64_t>">:$is_partial_profile,
         OptionalParameter<"FloatAttr">:$partial_profile_ratio,
         ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary);
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 6a523debf1b3a..37f07475b3f02 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -333,8 +333,8 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
   };
 
   if (summaryAttr.getIsPartialProfile())
-    vals.push_back(getIntTuple("IsPartialProfile",
-                               summaryAttr.getIsPartialProfile().getUInt()));
+    vals.push_back(
+        getIntTuple("IsPartialProfile", *summaryAttr.getIsPartialProfile()));
 
   if (summaryAttr.getPartialProfileRatio()) {
     SmallVector<llvm::Metadata *> tupleNodes{
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 72ffe2d09e4f8..028e3ca5d903e 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -657,16 +657,15 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
   };
 
   auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
-                            IntegerAttr &attr) -> LogicalResult {
+                            std::optional<uint64_t> &val) -> LogicalResult {
     if (!getConstantMD(md, matchKey, /*optional=*/true))
       return success();
     if (checkOptionalPosition(md, matchKey).failed())
       return failure();
-    uint64_t val = 0;
-    if (!getInt64Value(md, matchKey, val))
+    uint64_t tmpVal = 0;
+    if (!getInt64Value(md, matchKey, tmpVal))
       return failure();
-    attr =
-        IntegerAttr::get(IntegerType::get(mlirModule->getContext(), 64), val);
+    val = tmpVal;
     return success();
   };
 
@@ -771,12 +770,12 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     return nullptr;
 
   // Handle optional keys.
-  IntegerAttr isPartialProfile;
+  std::optional<uint64_t> isPartialProfile;
   if (getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile",
                      isPartialProfile)
           .failed())
     return nullptr;
-  if (isPartialProfile)
+  if (isPartialProfile.has_value())
     summayIdx++;
 
   FloatAttr partialProfileRatio;
@@ -795,8 +794,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
   // Build the final profile summary attribute.
   return ModuleFlagProfileSummaryAttr::get(
       mlirModule->getContext(), format, totalCount, maxCount, maxInternalCount,
-      maxFunctionCount, numCounts, numFunctions,
-      isPartialProfile ? isPartialProfile : nullptr,
+      maxFunctionCount, numCounts, numFunctions, isPartialProfile,
       partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary);
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 4f35358c9486a..bb730b28b947d 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1812,7 +1812,7 @@ llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
      #llvm.profile_summary<format = "YoloFmt", total_count = 263646, max_count = 86427,
        max_internal_count = 86427, max_function_count = 4691,
        num_counts = 3712, num_functions = 796,
-       is_partial_profile = 0 : i64,
+       is_partial_profile = 0,
        partial_profile_ratio = 0.000000e+00 : f64,
        detailed_summary =
          <cut_off = 10000, min_count = 86427, num_counts = 1>,
diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
index bd6162f15527c..148b1eb87fa75 100644
--- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
@@ -16,7 +16,7 @@ module {
                        #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
                          max_internal_count = 86427, max_function_count = 4691,
                          num_counts = 3712, num_functions = 796,
-                         is_partial_profile = 0 : i64,
+                         is_partial_profile = 0,
                          partial_profile_ratio = 0.000000e+00 : f64,
                          detailed_summary =
                            <cut_off = 10000, min_count = 86427, num_counts = 1>,
@@ -40,7 +40,7 @@ module {
 // CHECK-SAME:    #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
 // CHECK-SAME:      max_internal_count = 86427, max_function_count = 4691,
 // CHECK-SAME:      num_counts = 3712, num_functions = 796,
-// CHECK-SAME:      is_partial_profile = 0 : i64,
+// CHECK-SAME:      is_partial_profile = 0,
 // CHECK-SAME:      partial_profile_ratio = 0.000000e+00 : f64,
 // CHECK-SAME:      detailed_summary =
 // CHECK-SAME:        <cut_off = 10000, min_count = 86427, num_counts = 1>,
diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll
index 11df41a630d05..8e6a47921ee38 100644
--- a/mlir/test/Target/LLVMIR/Import/module-flags.ll
+++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll
@@ -60,7 +60,7 @@ declare void @to()
 ; CHECK: llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
 ; CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646,
 ; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
-; CHECK-SAME: num_counts = 3712, num_functions = 796, is_partial_profile = 0 : i64,
+; CHECK-SAME: num_counts = 3712, num_functions = 796, is_partial_profile = 0,
 ; CHECK-SAME: partial_profile_ratio = 0.000000e+00 : f64,
 ; CHECK-SAME: detailed_summary =
 ; CHECK-SAME: <cut_off = 10000, min_count = 86427, num_counts = 1>,
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index a8da126e698ec..1d4fd6b1cfd67 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2887,7 +2887,7 @@ llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
                        #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
                          max_internal_count = 86427, max_function_count = 4691,
                          num_counts = 3712, num_functions = 796,
-                         is_partial_profile = 0 : i64,
+                         is_partial_profile = 0,
                          partial_profile_ratio = 0.000000e+00 : f64,
                          detailed_summary =
                            <cut_off = 10000, min_count = 86427, num_counts = 1>,

>From def6b16f3955e6ed7383da74d422075cce32e792 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Thu, 1 May 2025 16:32:25 -0700
Subject: [PATCH 05/13] Use enum kind for format

---
 .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td       | 16 +++++------
 mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td | 17 +++++++++++-
 mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp      |  8 +-----
 .../LLVMIR/LLVMToLLVMIRTranslation.cpp        |  3 ++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       | 27 +++++++++----------
 mlir/test/Dialect/LLVMIR/invalid.mlir         |  4 ++-
 .../test/Dialect/LLVMIR/module-roundtrip.mlir |  4 +--
 .../test/Target/LLVMIR/Import/module-flags.ll |  4 +--
 mlir/test/Target/LLVMIR/llvmir.mlir           |  2 +-
 9 files changed, 48 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 5a037a767a75f..ade2b64c108ff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1404,7 +1404,7 @@ def ModuleFlagProfileSummaryAttr
     Describes ProfileSummary gathered data in a module. Example:
     ```mlir
     llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
-      #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+      #llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
         max_internal_count = 86427, max_function_count = 4691,
         num_counts = 3712, num_functions = 796,
         is_partial_profile = 0,
@@ -1415,13 +1415,13 @@ def ModuleFlagProfileSummaryAttr
     >>]
     ```
   }];
-  let parameters = (
-    ins "StringAttr":$format, "uint64_t":$total_count, "uint64_t":$max_count,
-        "uint64_t":$max_internal_count, "uint64_t":$max_function_count,
-        "uint64_t":$num_counts, "uint64_t":$num_functions,
-        OptionalParameter<"std::optional<uint64_t>">:$is_partial_profile,
-        OptionalParameter<"FloatAttr">:$partial_profile_ratio,
-        ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary);
+  let parameters = (ins "ProfileSummaryFormatKind":$format,
+    "uint64_t":$total_count, "uint64_t":$max_count,
+    "uint64_t":$max_internal_count, "uint64_t":$max_function_count,
+    "uint64_t":$num_counts, "uint64_t":$num_functions,
+    OptionalParameter<"std::optional<uint64_t>">:$is_partial_profile,
+    OptionalParameter<"FloatAttr">:$partial_profile_ratio,
+    ArrayRefParameter<"ModuleFlagProfileSummaryDetailedAttr">:$detailed_summary);
 
   let assemblyFormat = "`<` struct(params) `>`";
 }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index 6c0fe363d5551..7f5052948ab6c 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -823,7 +823,7 @@ def FPExceptionBehaviorAttr : LLVM_EnumAttr<
 }
 
 //===----------------------------------------------------------------------===//
-// Module Flag Behavior
+// Module Flags
 //===----------------------------------------------------------------------===//
 
 // These values must match llvm::Module::ModFlagBehavior ones.
@@ -855,6 +855,21 @@ def ModFlagBehaviorAttr : LLVM_EnumAttr<
   let cppNamespace = "::mlir::LLVM";
 }
 
+def LLVM_ProfileSummaryFormatSampleProfile : I64EnumAttrCase<"SampleProfile",
+                                                             0>;
+def LLVM_ProfileSummaryFormatInstrProf : I64EnumAttrCase<"InstrProf", 1>;
+def LLVM_ProfileSummaryFormatCSInstrProf : I64EnumAttrCase<"CSInstrProf", 2>;
+
+def LLVM_ProfileSummaryFormatKind : I64EnumAttr<
+    "ProfileSummaryFormatKind",
+    "LLVM ProfileSummary format kinds", [
+      LLVM_ProfileSummaryFormatSampleProfile,
+      LLVM_ProfileSummaryFormatInstrProf,
+      LLVM_ProfileSummaryFormatCSInstrProf,
+    ]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
 //===----------------------------------------------------------------------===//
 // UWTableKind
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index ef689e3721d91..d5815d39b364b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -391,15 +391,9 @@ ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   }
 
   if (key == LLVMDialect::getModuleFlagKeyProfileSummaryName()) {
-    if (auto summaryAttr = dyn_cast<ModuleFlagProfileSummaryAttr>(value)) {
-      StringRef fmt = summaryAttr.getFormat().getValue();
-      if (fmt != "SampleProfile" && fmt != "InstrProf" && fmt != "CSInstrProf")
-        return emitError() << "'ProfileFormat' must be 'SampleProfile', "
-                              "'InstrProf' or 'CSInstrProf'";
-    } else {
+    if (!isa<ModuleFlagProfileSummaryAttr>(value))
       return emitError() << "'ProfileSummary' key expects a "
                             "'#llvm.profile_summary' attribute";
-    }
     return success();
   }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 37f07475b3f02..1e517ceb827ac 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -320,7 +320,8 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
 
   SmallVector<llvm::Metadata *> fmtNode{
       mdb.createString("ProfileFormat"),
-      mdb.createString(summaryAttr.getFormat().getValue())};
+      mdb.createString(
+          stringifyProfileSummaryFormatKind(summaryAttr.getFormat()))};
 
   SmallVector<llvm::Metadata *> vals = {
       llvm::MDTuple::get(context, fmtNode),
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 028e3ca5d903e..13cd9229846c9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -576,34 +576,32 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     return tupleEntry;
   };
 
-  auto getFormat = [&](const llvm::MDOperand &formatMD) -> StringAttr {
+  auto getFormat = [&](const llvm::MDOperand &formatMD)
+      -> std::optional<ProfileSummaryFormatKind> {
     auto *tupleEntry = getMDTuple(formatMD);
     if (!tupleEntry)
-      return nullptr;
+      return std::nullopt;
 
     llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
     if (!keyMD || keyMD->getString() != "ProfileFormat") {
       emitWarning(mlirModule.getLoc())
           << "expected 'ProfileFormat' key: "
           << diagMD(tupleEntry->getOperand(0), llvmModule);
-      return nullptr;
+      return std::nullopt;
     }
 
     llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(1));
-    auto formatAttr = llvm::StringSwitch<std::string>(valMD->getString())
-                          .Case("SampleProfile", "SampleProfile")
-                          .Case("InstrProf", "InstrProf")
-                          .Case("CSInstrProf", "CSInstrProf")
-                          .Default("");
-    if (formatAttr.empty()) {
+    std::optional<ProfileSummaryFormatKind> fmtKind =
+        symbolizeProfileSummaryFormatKind(valMD->getString());
+    if (!fmtKind) {
       emitWarning(mlirModule.getLoc())
           << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
              "but found: "
           << diagMD(valMD, llvmModule);
-      return nullptr;
+      return std::nullopt;
     }
 
-    return StringAttr::get(mlirModule->getContext(), formatAttr);
+    return fmtKind;
   };
 
   auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
@@ -746,8 +744,9 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
   // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
   // a fixed order: format, total count, etc.
   SmallVector<Attribute> profileSummary;
-  StringAttr format = getFormat(mdTuple->getOperand(summayIdx++));
-  if (!format)
+  std::optional<ProfileSummaryFormatKind> format =
+      getFormat(mdTuple->getOperand(summayIdx++));
+  if (!format.has_value())
     return nullptr;
 
   uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
@@ -793,7 +792,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
 
   // Build the final profile summary attribute.
   return ModuleFlagProfileSummaryAttr::get(
-      mlirModule->getContext(), format, totalCount, maxCount, maxInternalCount,
+      mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount,
       maxFunctionCount, numCounts, numFunctions, isPartialProfile,
       partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary);
 }
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index bb730b28b947d..f9ea066a63624 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1807,9 +1807,10 @@ module {
 
 // -----
 
-// expected-error at below {{'ProfileFormat' must be 'SampleProfile', 'InstrProf' or 'CSInstrProf'}}
 llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
+     // expected-error at below {{expected one of [SampleProfile, InstrProf, CSInstrProf] for LLVM ProfileSummary format kinds, got: YoloFmt}}
      #llvm.profile_summary<format = "YoloFmt", total_count = 263646, max_count = 86427,
+     // expected-error at above {{failed to parse ModuleFlagProfileSummaryAttr parameter 'format' which is to be a `ProfileSummaryFormatKind`}}
        max_internal_count = 86427, max_function_count = 4691,
        num_counts = 3712, num_functions = 796,
        is_partial_profile = 0,
@@ -1817,6 +1818,7 @@ llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
        detailed_summary =
          <cut_off = 10000, min_count = 86427, num_counts = 1>,
          <cut_off = 100000, min_count = 86427, num_counts = 1>
+      // expected-error at below {{failed to parse ModuleFlagAttr parameter}}
 >>]
 
 // -----
diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
index 148b1eb87fa75..3935a1f5bc621 100644
--- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
@@ -13,7 +13,7 @@ module {
                        #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
                     ]>,
                     #llvm.mlir.module_flag<error, "ProfileSummary",
-                       #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+                       #llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
                          max_internal_count = 86427, max_function_count = 4691,
                          num_counts = 3712, num_functions = 796,
                          is_partial_profile = 0,
@@ -37,7 +37,7 @@ module {
 // CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
 // CHECK-SAME: ]>,
 // CHECK-SAME: #llvm.mlir.module_flag<error, "ProfileSummary",
-// CHECK-SAME:    #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+// CHECK-SAME:    #llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
 // CHECK-SAME:      max_internal_count = 86427, max_function_count = 4691,
 // CHECK-SAME:      num_counts = 3712, num_functions = 796,
 // CHECK-SAME:      is_partial_profile = 0,
diff --git a/mlir/test/Target/LLVMIR/Import/module-flags.ll b/mlir/test/Target/LLVMIR/Import/module-flags.ll
index 8e6a47921ee38..725bd14deb651 100644
--- a/mlir/test/Target/LLVMIR/Import/module-flags.ll
+++ b/mlir/test/Target/LLVMIR/Import/module-flags.ll
@@ -58,7 +58,7 @@ declare void @to()
 !31887 = !{i32 100000, i64 86427, i32 1}
 
 ; CHECK: llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
-; CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646,
+; CHECK-SAME: #llvm.profile_summary<format = InstrProf, total_count = 263646,
 ; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
 ; CHECK-SAME: num_counts = 3712, num_functions = 796, is_partial_profile = 0,
 ; CHECK-SAME: partial_profile_ratio = 0.000000e+00 : f64,
@@ -88,7 +88,7 @@ declare void @to()
 !41887 = !{i32 100000, i64 86427, i32 1}
 
 ; CHECK: llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
-; CHECK-SAME: #llvm.profile_summary<format = "InstrProf", total_count = 263646,
+; CHECK-SAME: #llvm.profile_summary<format = InstrProf, total_count = 263646,
 ; CHECK-SAME: max_count = 86427, max_internal_count = 86427, max_function_count = 4691,
 ; CHECK-SAME: num_counts = 3712, num_functions = 796,
 ; CHECK-SAME: detailed_summary =
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 1d4fd6b1cfd67..1b36dc9672f0c 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2884,7 +2884,7 @@ llvm.func @to()
 // -----
 
 llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
-                       #llvm.profile_summary<format = "InstrProf", total_count = 263646, max_count = 86427,
+                       #llvm.profile_summary<format = InstrProf, total_count = 263646, max_count = 86427,
                          max_internal_count = 86427, max_function_count = 4691,
                          num_counts = 3712, num_functions = 796,
                          is_partial_profile = 0,

>From d81fccef7adac2b304495bf93b58bb80f3e57127 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Thu, 1 May 2025 16:47:17 -0700
Subject: [PATCH 06/13] Regex tests and remove auto

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp |  2 +-
 mlir/test/Target/LLVMIR/llvmir.mlir     | 30 ++++++++++++++-----------
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 13cd9229846c9..91066dbc44058 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -607,7 +607,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
   auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
                            bool optional =
                                false) -> llvm::ConstantAsMetadata * {
-    auto *tupleEntry = getMDTuple(md);
+    llvm::MDTuple *tupleEntry = getMDTuple(md);
     if (!tupleEntry)
       return nullptr;
     llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 1b36dc9672f0c..854034f3ec243 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2894,19 +2894,23 @@ llvm.module_flags [#llvm.mlir.module_flag<error, "ProfileSummary",
                            <cut_off = 100000, min_count = 86427, num_counts = 1>
                   >>]
 
-// CHECK: !llvm.module.flags = !{!0, !15}
-
-// CHECK: !0 = !{i32 1, !"ProfileSummary", !1}
-// CHECK: !1 = !{!2, !3, !4, !5, !6, !7, !8, !9, !10, !11}
-// CHECK: !2 = !{!"ProfileFormat", !"InstrProf"}
-// CHECK: !3 = !{!"TotalCount", i64 263646}
-// CHECK: !4 = !{!"MaxCount", i64 86427}
-// CHECK: !5 = !{!"MaxInternalCount", i64 86427}
-// CHECK: !6 = !{!"MaxFunctionCount", i64 4691}
-// CHECK: !7 = !{!"NumCounts", i64 3712}
-// CHECK: !8 = !{!"NumFunctions", i64 796}
-// CHECK: !9 = !{!"IsPartialProfile", i64 0}
-// CHECK: !10 = !{!"PartialProfileRatio", double 0.000000e+00}
+// CHECK: !llvm.module.flags = !{![[#PSUM:]], {{.*}}}
+
+// CHECK: ![[#PSUM]] = !{i32 1, !"ProfileSummary", ![[#SUMLIST:]]}
+// CHECK: ![[#SUMLIST]] = !{![[#FMT:]], ![[#TC:]], ![[#MC:]], ![[#MIC:]], ![[#MFC:]], ![[#NC:]], ![[#NF:]], ![[#IPP:]], ![[#PPR:]], ![[#DS:]]}
+// CHECK: ![[#FMT]] = !{!"ProfileFormat", !"InstrProf"}
+// CHECK: ![[#TC]] = !{!"TotalCount", i64 263646}
+// CHECK: ![[#MC]] = !{!"MaxCount", i64 86427}
+// CHECK: ![[#MIC]] = !{!"MaxInternalCount", i64 86427}
+// CHECK: ![[#MFC]] = !{!"MaxFunctionCount", i64 4691}
+// CHECK: ![[#NC]] = !{!"NumCounts", i64 3712}
+// CHECK: ![[#NF]] = !{!"NumFunctions", i64 796}
+// CHECK: ![[#IPP]] = !{!"IsPartialProfile", i64 0}
+// CHECK: ![[#PPR]] = !{!"PartialProfileRatio", double 0.000000e+00}
+// CHECK: ![[#DS]] = !{!"DetailedSummary", ![[#DETAILED:]]}
+// CHECK: ![[#DETAILED]] = !{![[#DS0:]], ![[#DS1:]]}
+// CHECK: ![[#DS0:]] = !{i64 10000, i64 86427, i64 1}
+// CHECK: ![[#DS1:]] = !{i64 100000, i64 86427, i64 1}
 
 // -----
 

>From 0eff0a449515f8bd2f62b410108f8c80cb01f704 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Thu, 1 May 2025 16:59:27 -0700
Subject: [PATCH 07/13] move some lambdas to static local functions

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp | 307 +++++++++++++-----------
 1 file changed, 162 insertions(+), 145 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 91066dbc44058..edd958821a135 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -554,74 +554,152 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
   return ArrayAttr::get(mlirModule->getContext(), cgProfile);
 }
 
-static Attribute
-convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
-                                     const llvm::Module *llvmModule,
-                                     llvm::MDTuple *mdTuple) {
-  unsigned profileNumEntries = mdTuple->getNumOperands();
-  if (profileNumEntries < 8) {
+static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule,
+                                           const llvm::Module *llvmModule,
+                                           const llvm::MDOperand &md) {
+  auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
+  if (!tupleEntry || tupleEntry->getNumOperands() != 2)
     emitWarning(mlirModule.getLoc())
-        << "expected at 8 entries in 'ProfileSummary': "
-        << diagMD(mdTuple, llvmModule);
+        << "expected 2-element tuple metadata: " << diagMD(md, llvmModule);
+  return tupleEntry;
+}
+
+static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
+    ModuleOp mlirModule, const llvm::Module *llvmModule,
+    const llvm::MDOperand &md, StringRef matchKey, bool optional = false) {
+  llvm::MDTuple *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, md);
+  if (!tupleEntry)
+    return nullptr;
+  llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != matchKey) {
+    if (!optional)
+      emitWarning(mlirModule.getLoc())
+          << "expected '" << matchKey << "' key, but found: "
+          << diagMD(tupleEntry->getOperand(0), llvmModule);
     return nullptr;
   }
 
-  unsigned summayIdx = 0;
+  return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
+}
 
-  auto getMDTuple = [&](const llvm::MDOperand &md) {
-    auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
-    if (!tupleEntry || tupleEntry->getNumOperands() != 2)
-      emitWarning(mlirModule.getLoc())
-          << "expected 2-element tuple metadata: " << diagMD(md, llvmModule);
-    return tupleEntry;
-  };
+static bool convertInt64FromKeyValueTuple(ModuleOp mlirModule,
+                                          const llvm::Module *llvmModule,
+                                          const llvm::MDOperand &md,
+                                          StringRef matchKey, uint64_t &val) {
+  auto *valMD =
+      getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey);
+  if (!valMD)
+    return false;
 
-  auto getFormat = [&](const llvm::MDOperand &formatMD)
-      -> std::optional<ProfileSummaryFormatKind> {
-    auto *tupleEntry = getMDTuple(formatMD);
-    if (!tupleEntry)
-      return std::nullopt;
+  if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue())) {
+    val = cstInt->getZExtValue();
+    return true;
+  }
+
+  emitWarning(mlirModule.getLoc())
+      << "expected integer metadata value for key '" << matchKey
+      << "': " << diagMD(md, llvmModule);
+  return false;
+}
+
+static std::optional<ProfileSummaryFormatKind>
+convertProfileSummaryFormat(ModuleOp mlirModule, const llvm::Module *llvmModule,
+                            const llvm::MDOperand &formatMD) {
+  auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, formatMD);
+  if (!tupleEntry)
+    return std::nullopt;
+
+  llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != "ProfileFormat") {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'ProfileFormat' key: "
+        << diagMD(tupleEntry->getOperand(0), llvmModule);
+    return std::nullopt;
+  }
+
+  llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(1));
+  std::optional<ProfileSummaryFormatKind> fmtKind =
+      symbolizeProfileSummaryFormatKind(valMD->getString());
+  if (!fmtKind) {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
+           "but found: "
+        << diagMD(valMD, llvmModule);
+    return std::nullopt;
+  }
+
+  return fmtKind;
+}
+
+static bool convertProfileSummaryDetailed(
+    ModuleOp mlirModule, const llvm::Module *llvmModule,
+    const llvm::MDOperand &summaryMD,
+    SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr> &detailedSummary) {
+  auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, summaryMD);
+  if (!tupleEntry)
+    return false;
+
+  llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != "DetailedSummary") {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'DetailedSummary' key: "
+        << diagMD(tupleEntry->getOperand(0), llvmModule);
+    return false;
+  }
 
-    llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
-    if (!keyMD || keyMD->getString() != "ProfileFormat") {
+  llvm::MDTuple *entriesMD = dyn_cast<llvm::MDTuple>(tupleEntry->getOperand(1));
+  if (!entriesMD) {
+    emitWarning(mlirModule.getLoc())
+        << "expected tuple value for 'DetailedSummary' key: "
+        << diagMD(tupleEntry->getOperand(1), llvmModule);
+    return false;
+  }
+
+  for (auto &&entry : entriesMD->operands()) {
+    llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
+    if (!entryMD || entryMD->getNumOperands() != 3) {
       emitWarning(mlirModule.getLoc())
-          << "expected 'ProfileFormat' key: "
-          << diagMD(tupleEntry->getOperand(0), llvmModule);
-      return std::nullopt;
+          << "'DetailedSummary' entry expects 3 operands: "
+          << diagMD(entry, llvmModule);
+      return false;
     }
-
-    llvm::MDString *valMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(1));
-    std::optional<ProfileSummaryFormatKind> fmtKind =
-        symbolizeProfileSummaryFormatKind(valMD->getString());
-    if (!fmtKind) {
+    llvm::ConstantAsMetadata *op0 =
+        dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(0));
+    llvm::ConstantAsMetadata *op1 =
+        dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(1));
+    llvm::ConstantAsMetadata *op2 =
+        dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(2));
+
+    if (!op0 || !op1 || !op2) {
       emitWarning(mlirModule.getLoc())
-          << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
-             "but found: "
-          << diagMD(valMD, llvmModule);
-      return std::nullopt;
+          << "expected only integer entries in 'DetailedSummary': "
+          << diagMD(entry, llvmModule);
+      return false;
     }
 
-    return fmtKind;
-  };
-
-  auto getConstantMD = [&](const llvm::MDOperand &md, StringRef matchKey,
-                           bool optional =
-                               false) -> llvm::ConstantAsMetadata * {
-    llvm::MDTuple *tupleEntry = getMDTuple(md);
-    if (!tupleEntry)
-      return nullptr;
-    llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
-    if (!keyMD || keyMD->getString() != matchKey) {
-      if (!optional)
-        emitWarning(mlirModule.getLoc())
-            << "expected '" << matchKey << "' key, but found: "
-            << diagMD(tupleEntry->getOperand(0), llvmModule);
-      return nullptr;
-    }
+    auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
+        mlirModule->getContext(),
+        cast<llvm::ConstantInt>(op0->getValue())->getZExtValue(),
+        cast<llvm::ConstantInt>(op1->getValue())->getZExtValue(),
+        cast<llvm::ConstantInt>(op2->getValue())->getZExtValue());
+    detailedSummary.push_back(detaildSummaryEntry);
+  }
+  return true;
+}
 
-    return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
-  };
+static Attribute
+convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
+                                     const llvm::Module *llvmModule,
+                                     llvm::MDTuple *mdTuple) {
+  unsigned profileNumEntries = mdTuple->getNumOperands();
+  if (profileNumEntries < 8) {
+    emitWarning(mlirModule.getLoc())
+        << "expected at 8 entries in 'ProfileSummary': "
+        << diagMD(mdTuple, llvmModule);
+    return nullptr;
+  }
 
+  unsigned summayIdx = 0;
   auto checkOptionalPosition = [&](const llvm::MDOperand &md,
                                    StringRef matchKey) -> LogicalResult {
     // Make sure we won't step over the bound of the array of summary entries.
@@ -637,31 +715,16 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     return success();
   };
 
-  auto getInt64Value = [&](const llvm::MDOperand &md, StringRef matchKey,
-                           uint64_t &val) {
-    auto *valMD = getConstantMD(md, matchKey);
-    if (!valMD)
-      return false;
-
-    if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue())) {
-      val = cstInt->getZExtValue();
-      return true;
-    }
-
-    emitWarning(mlirModule.getLoc())
-        << "expected integer metadata value for key '" << matchKey
-        << "': " << diagMD(md, llvmModule);
-    return false;
-  };
-
   auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
                             std::optional<uint64_t> &val) -> LogicalResult {
-    if (!getConstantMD(md, matchKey, /*optional=*/true))
+    if (!getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
+                                        /*optional=*/true))
       return success();
     if (checkOptionalPosition(md, matchKey).failed())
       return failure();
     uint64_t tmpVal = 0;
-    if (!getInt64Value(md, matchKey, tmpVal))
+    if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
+                                       tmpVal))
       return failure();
     val = tmpVal;
     return success();
@@ -669,7 +732,8 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
 
   auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
                                FloatAttr &attr) -> LogicalResult {
-    auto *valMD = getConstantMD(md, matchKey, /*optional=*/true);
+    auto *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md,
+                                                 matchKey, /*optional=*/true);
     if (!valMD)
       return success();
     if (auto *cstFP = dyn_cast<llvm::ConstantFP>(valMD->getValue())) {
@@ -685,87 +749,39 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     return failure();
   };
 
-  auto getSummary = [&](const llvm::MDOperand &summaryMD,
-                        SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr>
-                            &detailedSummary) {
-    auto *tupleEntry = getMDTuple(summaryMD);
-    if (!tupleEntry)
-      return false;
-
-    llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
-    if (!keyMD || keyMD->getString() != "DetailedSummary") {
-      emitWarning(mlirModule.getLoc())
-          << "expected 'DetailedSummary' key: "
-          << diagMD(tupleEntry->getOperand(0), llvmModule);
-      return false;
-    }
-
-    llvm::MDTuple *entriesMD =
-        dyn_cast<llvm::MDTuple>(tupleEntry->getOperand(1));
-    if (!entriesMD) {
-      emitWarning(mlirModule.getLoc())
-          << "expected tuple value for 'DetailedSummary' key: "
-          << diagMD(tupleEntry->getOperand(1), llvmModule);
-      return false;
-    }
-
-    for (auto &&entry : entriesMD->operands()) {
-      llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
-      if (!entryMD || entryMD->getNumOperands() != 3) {
-        emitWarning(mlirModule.getLoc())
-            << "'DetailedSummary' entry expects 3 operands: "
-            << diagMD(entry, llvmModule);
-        return false;
-      }
-      llvm::ConstantAsMetadata *op0 =
-          dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(0));
-      llvm::ConstantAsMetadata *op1 =
-          dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(1));
-      llvm::ConstantAsMetadata *op2 =
-          dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(2));
-
-      if (!op0 || !op1 || !op2) {
-        emitWarning(mlirModule.getLoc())
-            << "expected only integer entries in 'DetailedSummary': "
-            << diagMD(entry, llvmModule);
-        return false;
-      }
-
-      auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
-          mlirModule->getContext(),
-          cast<llvm::ConstantInt>(op0->getValue())->getZExtValue(),
-          cast<llvm::ConstantInt>(op1->getValue())->getZExtValue(),
-          cast<llvm::ConstantInt>(op2->getValue())->getZExtValue());
-      detailedSummary.push_back(detaildSummaryEntry);
-    }
-    return true;
-  };
-
   // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
   // a fixed order: format, total count, etc.
   SmallVector<Attribute> profileSummary;
-  std::optional<ProfileSummaryFormatKind> format =
-      getFormat(mdTuple->getOperand(summayIdx++));
+  std::optional<ProfileSummaryFormatKind> format = convertProfileSummaryFormat(
+      mlirModule, llvmModule, mdTuple->getOperand(summayIdx++));
   if (!format.has_value())
     return nullptr;
 
   uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
            maxFunctionCount = 0, numCounts = 0, numFunctions = 0;
-  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "TotalCount",
-                     totalCount))
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "TotalCount", totalCount))
     return nullptr;
-  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxCount", maxCount))
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "MaxCount", maxCount))
     return nullptr;
-  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxInternalCount",
-                     maxInternalCount))
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "MaxInternalCount", maxInternalCount))
     return nullptr;
-  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "MaxFunctionCount",
-                     maxFunctionCount))
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "MaxFunctionCount", maxFunctionCount))
     return nullptr;
-  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumCounts", numCounts))
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "NumCounts", numCounts))
     return nullptr;
-  if (!getInt64Value(mdTuple->getOperand(summayIdx++), "NumFunctions",
-                     numFunctions))
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "NumFunctions", numFunctions))
     return nullptr;
 
   // Handle optional keys.
@@ -786,15 +802,16 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     summayIdx++;
 
   // Handle detailed summary.
-  SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailedSummary;
-  if (!getSummary(mdTuple->getOperand(summayIdx), detailedSummary))
+  SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailed;
+  if (!convertProfileSummaryDetailed(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx), detailed))
     return nullptr;
 
   // Build the final profile summary attribute.
   return ModuleFlagProfileSummaryAttr::get(
       mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount,
       maxFunctionCount, numCounts, numFunctions, isPartialProfile,
-      partialProfileRatio ? partialProfileRatio : nullptr, detailedSummary);
+      partialProfileRatio ? partialProfileRatio : nullptr, detailed);
 }
 
 /// Invoke specific handlers for each known module flag value, returns nullptr

>From 6f203a2bf89773f7bd132c300e9d91b1a14ad718 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Fri, 2 May 2025 15:14:05 -0700
Subject: [PATCH 08/13] Use FailureOr

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp | 105 +++++++++++++-----------
 1 file changed, 55 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index edd958821a135..5d4190022e5b3 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -570,7 +570,7 @@ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
   llvm::MDTuple *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, md);
   if (!tupleEntry)
     return nullptr;
-  llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
+  auto *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
   if (!keyMD || keyMD->getString() != matchKey) {
     if (!optional)
       emitWarning(mlirModule.getLoc())
@@ -582,24 +582,22 @@ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
   return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
 }
 
-static bool convertInt64FromKeyValueTuple(ModuleOp mlirModule,
-                                          const llvm::Module *llvmModule,
-                                          const llvm::MDOperand &md,
-                                          StringRef matchKey, uint64_t &val) {
+static FailureOr<uint64_t>
+convertInt64FromKeyValueTuple(ModuleOp mlirModule,
+                              const llvm::Module *llvmModule,
+                              const llvm::MDOperand &md, StringRef matchKey) {
   auto *valMD =
       getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey);
   if (!valMD)
-    return false;
+    return failure();
 
-  if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue())) {
-    val = cstInt->getZExtValue();
-    return true;
-  }
+  if (auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue()))
+    return cstInt->getZExtValue();
 
   emitWarning(mlirModule.getLoc())
       << "expected integer metadata value for key '" << matchKey
       << "': " << diagMD(md, llvmModule);
-  return false;
+  return failure();
 }
 
 static std::optional<ProfileSummaryFormatKind>
@@ -631,20 +629,20 @@ convertProfileSummaryFormat(ModuleOp mlirModule, const llvm::Module *llvmModule,
   return fmtKind;
 }
 
-static bool convertProfileSummaryDetailed(
-    ModuleOp mlirModule, const llvm::Module *llvmModule,
-    const llvm::MDOperand &summaryMD,
-    SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr> &detailedSummary) {
+static FailureOr<SmallVector<ModuleFlagProfileSummaryDetailedAttr>>
+convertProfileSummaryDetailed(ModuleOp mlirModule,
+                              const llvm::Module *llvmModule,
+                              const llvm::MDOperand &summaryMD) {
   auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, summaryMD);
   if (!tupleEntry)
-    return false;
+    return failure();
 
   llvm::MDString *keyMD = dyn_cast<llvm::MDString>(tupleEntry->getOperand(0));
   if (!keyMD || keyMD->getString() != "DetailedSummary") {
     emitWarning(mlirModule.getLoc())
         << "expected 'DetailedSummary' key: "
         << diagMD(tupleEntry->getOperand(0), llvmModule);
-    return false;
+    return failure();
   }
 
   llvm::MDTuple *entriesMD = dyn_cast<llvm::MDTuple>(tupleEntry->getOperand(1));
@@ -652,16 +650,17 @@ static bool convertProfileSummaryDetailed(
     emitWarning(mlirModule.getLoc())
         << "expected tuple value for 'DetailedSummary' key: "
         << diagMD(tupleEntry->getOperand(1), llvmModule);
-    return false;
+    return failure();
   }
 
+  SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailedSummary;
   for (auto &&entry : entriesMD->operands()) {
     llvm::MDTuple *entryMD = dyn_cast<llvm::MDTuple>(entry);
     if (!entryMD || entryMD->getNumOperands() != 3) {
       emitWarning(mlirModule.getLoc())
           << "'DetailedSummary' entry expects 3 operands: "
           << diagMD(entry, llvmModule);
-      return false;
+      return failure();
     }
     llvm::ConstantAsMetadata *op0 =
         dyn_cast<llvm::ConstantAsMetadata>(entryMD->getOperand(0));
@@ -674,7 +673,7 @@ static bool convertProfileSummaryDetailed(
       emitWarning(mlirModule.getLoc())
           << "expected only integer entries in 'DetailedSummary': "
           << diagMD(entry, llvmModule);
-      return false;
+      return failure();
     }
 
     auto detaildSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
@@ -684,7 +683,7 @@ static bool convertProfileSummaryDetailed(
         cast<llvm::ConstantInt>(op2->getValue())->getZExtValue());
     detailedSummary.push_back(detaildSummaryEntry);
   }
-  return true;
+  return detailedSummary;
 }
 
 static Attribute
@@ -722,9 +721,9 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
       return success();
     if (checkOptionalPosition(md, matchKey).failed())
       return failure();
-    uint64_t tmpVal = 0;
-    if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
-                                       tmpVal))
+    FailureOr<uint64_t> tmpVal =
+        convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey);
+    if (failed(tmpVal))
       return failure();
     val = tmpVal;
     return success();
@@ -757,31 +756,36 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
   if (!format.has_value())
     return nullptr;
 
-  uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
-           maxFunctionCount = 0, numCounts = 0, numFunctions = 0;
-  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
-                                     mdTuple->getOperand(summayIdx++),
-                                     "TotalCount", totalCount))
+  FailureOr<uint64_t> totalCount = convertInt64FromKeyValueTuple(
+      mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "TotalCount");
+  if (failed(totalCount))
     return nullptr;
-  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
-                                     mdTuple->getOperand(summayIdx++),
-                                     "MaxCount", maxCount))
+
+  FailureOr<uint64_t> maxCount = convertInt64FromKeyValueTuple(
+      mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "MaxCount");
+  if (failed(maxCount))
     return nullptr;
-  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
-                                     mdTuple->getOperand(summayIdx++),
-                                     "MaxInternalCount", maxInternalCount))
+
+  FailureOr<uint64_t> maxInternalCount = convertInt64FromKeyValueTuple(
+      mlirModule, llvmModule, mdTuple->getOperand(summayIdx++),
+      "MaxInternalCount");
+  if (failed(maxInternalCount))
     return nullptr;
-  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
-                                     mdTuple->getOperand(summayIdx++),
-                                     "MaxFunctionCount", maxFunctionCount))
+
+  FailureOr<uint64_t> maxFunctionCount = convertInt64FromKeyValueTuple(
+      mlirModule, llvmModule, mdTuple->getOperand(summayIdx++),
+      "MaxFunctionCount");
+  if (failed(maxFunctionCount))
     return nullptr;
-  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
-                                     mdTuple->getOperand(summayIdx++),
-                                     "NumCounts", numCounts))
+
+  FailureOr<uint64_t> numCounts = convertInt64FromKeyValueTuple(
+      mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "NumCounts");
+  if (failed(numCounts))
     return nullptr;
-  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
-                                     mdTuple->getOperand(summayIdx++),
-                                     "NumFunctions", numFunctions))
+
+  FailureOr<uint64_t> numFunctions = convertInt64FromKeyValueTuple(
+      mlirModule, llvmModule, mdTuple->getOperand(summayIdx++), "NumFunctions");
+  if (failed(numFunctions))
     return nullptr;
 
   // Handle optional keys.
@@ -802,16 +806,17 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     summayIdx++;
 
   // Handle detailed summary.
-  SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailed;
-  if (!convertProfileSummaryDetailed(mlirModule, llvmModule,
-                                     mdTuple->getOperand(summayIdx), detailed))
+  FailureOr<SmallVector<ModuleFlagProfileSummaryDetailedAttr>> detailed =
+      convertProfileSummaryDetailed(mlirModule, llvmModule,
+                                    mdTuple->getOperand(summayIdx));
+  if (failed(detailed))
     return nullptr;
 
   // Build the final profile summary attribute.
   return ModuleFlagProfileSummaryAttr::get(
-      mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount,
-      maxFunctionCount, numCounts, numFunctions, isPartialProfile,
-      partialProfileRatio ? partialProfileRatio : nullptr, detailed);
+      mlirModule->getContext(), *format, *totalCount, *maxCount,
+      *maxInternalCount, *maxFunctionCount, *numCounts, *numFunctions,
+      isPartialProfile, partialProfileRatio, *detailed);
 }
 
 /// Invoke specific handlers for each known module flag value, returns nullptr

>From 181fb6bf741d03ed2fc3b01e1155e187a2e986b6 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Fri, 2 May 2025 15:52:20 -0700
Subject: [PATCH 09/13] more nits and cleanup

---
 .../Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 5 +++--
 mlir/test/Dialect/LLVMIR/module-roundtrip.mlir               | 2 --
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index 1e517ceb827ac..e57aecd13916f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -347,7 +347,8 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
   }
 
   SmallVector<llvm::Metadata *> detailedEntries;
-  for (auto detailedEntry : summaryAttr.getDetailedSummary()) {
+  for (ModuleFlagProfileSummaryDetailedAttr detailedEntry :
+       summaryAttr.getDetailedSummary()) {
     SmallVector<llvm::Metadata *> tupleNodes{
         mdb.createConstant(llvm::ConstantInt::get(
             llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())),
@@ -385,7 +386,7 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
                                             arrayAttr, builder,
                                             moduleTranslation);
             })
-            .Case<ModuleFlagProfileSummaryAttr>([&](auto summaryAttr) {
+            .Case([&](ModuleFlagProfileSummaryAttr summaryAttr) {
               return convertModuleFlagProfileSummaryAttr(
                   flagAttr.getKey().getValue(), summaryAttr, builder,
                   moduleTranslation);
diff --git a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
index 3935a1f5bc621..85abd57df53c8 100644
--- a/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/module-roundtrip.mlir
@@ -46,5 +46,3 @@ module {
 // CHECK-SAME:        <cut_off = 10000, min_count = 86427, num_counts = 1>,
 // CHECK-SAME:        <cut_off = 100000, min_count = 86427, num_counts = 1>
 // CHECK-SAME: >>]
-
-llvm.module_flags []

>From 9052f26607a39adad887a482db1b006f9e0c5b8f Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Fri, 2 May 2025 16:00:01 -0700
Subject: [PATCH 10/13] add doc to helper functions

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 5d4190022e5b3..383c3963101d7 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -554,6 +554,8 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
   return ArrayAttr::get(mlirModule->getContext(), cgProfile);
 }
 
+/// Extract a two-element `MDTuple` from an `MDOperand`. Emit a warning in case
+/// something else is found.
 static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule,
                                            const llvm::Module *llvmModule,
                                            const llvm::MDOperand &md) {
@@ -564,6 +566,9 @@ static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule,
   return tupleEntry;
 }
 
+/// Extract a constant metadata value from a two-element tuple (<key, value>).
+/// Return nullptr if requirements are not met. A warning is emitted if the
+/// `matchKey` is different from the tuple's key.
 static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
     ModuleOp mlirModule, const llvm::Module *llvmModule,
     const llvm::MDOperand &md, StringRef matchKey, bool optional = false) {
@@ -582,6 +587,9 @@ static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
   return dyn_cast<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
 }
 
+/// Extract an integer value from a two-element tuple (<key, value>).
+/// Fail if requirements are not met. A warning is emitted if the
+/// found value isn't an LLVM constant integer.
 static FailureOr<uint64_t>
 convertInt64FromKeyValueTuple(ModuleOp mlirModule,
                               const llvm::Module *llvmModule,

>From c5056ac767364f98dfc07fc74149c325ac1ae974 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Fri, 2 May 2025 16:04:17 -0700
Subject: [PATCH 11/13] hoist type creation

---
 .../LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index e57aecd13916f..82bdc51145d1c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -347,15 +347,16 @@ static llvm::Metadata *convertModuleFlagProfileSummaryAttr(
   }
 
   SmallVector<llvm::Metadata *> detailedEntries;
+  llvm::Type *llvmInt64Type = llvm::Type::getInt64Ty(context);
   for (ModuleFlagProfileSummaryDetailedAttr detailedEntry :
        summaryAttr.getDetailedSummary()) {
     SmallVector<llvm::Metadata *> tupleNodes{
+        mdb.createConstant(
+            llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getCutOff())),
+        mdb.createConstant(
+            llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getMinCount())),
         mdb.createConstant(llvm::ConstantInt::get(
-            llvm::Type::getInt64Ty(context), detailedEntry.getCutOff())),
-        mdb.createConstant(llvm::ConstantInt::get(
-            llvm::Type::getInt64Ty(context), detailedEntry.getMinCount())),
-        mdb.createConstant(llvm::ConstantInt::get(
-            llvm::Type::getInt64Ty(context), detailedEntry.getNumCounts()))};
+            llvmInt64Type, detailedEntry.getNumCounts()))};
     detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes));
   }
   SmallVector<llvm::Metadata *> detailedSummary{

>From b1096b77661aa2477e528bdf9ac970409e69ec8a Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Fri, 2 May 2025 16:13:43 -0700
Subject: [PATCH 12/13] Use FailureOr for getOptIntValue

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 383c3963101d7..4a0b4c29f9dc4 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -41,6 +41,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Support/ModRef.h"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::LLVM;
@@ -722,11 +723,13 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     return success();
   };
 
-  auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
-                            std::optional<uint64_t> &val) -> LogicalResult {
+  auto getOptIntValue =
+      [&](const llvm::MDOperand &md,
+          StringRef matchKey) -> FailureOr<std::optional<uint64_t>> {
+    std::optional<uint64_t> val = std::nullopt;
     if (!getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
                                         /*optional=*/true))
-      return success();
+      return val;
     if (checkOptionalPosition(md, matchKey).failed())
       return failure();
     FailureOr<uint64_t> tmpVal =
@@ -734,7 +737,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     if (failed(tmpVal))
       return failure();
     val = tmpVal;
-    return success();
+    return val;
   };
 
   auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
@@ -797,12 +800,11 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     return nullptr;
 
   // Handle optional keys.
-  std::optional<uint64_t> isPartialProfile;
-  if (getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile",
-                     isPartialProfile)
-          .failed())
+  FailureOr<std::optional<uint64_t>> isPartialProfile =
+      getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile");
+  if (failed(isPartialProfile))
     return nullptr;
-  if (isPartialProfile.has_value())
+  if (isPartialProfile->has_value())
     summayIdx++;
 
   FloatAttr partialProfileRatio;
@@ -824,7 +826,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
   return ModuleFlagProfileSummaryAttr::get(
       mlirModule->getContext(), *format, *totalCount, *maxCount,
       *maxInternalCount, *maxFunctionCount, *numCounts, *numFunctions,
-      isPartialProfile, partialProfileRatio, *detailed);
+      *isPartialProfile, partialProfileRatio, *detailed);
 }
 
 /// Invoke specific handlers for each known module flag value, returns nullptr

>From 6f0bd126cdbad229356c4c39cae2c289a9eaffe1 Mon Sep 17 00:00:00 2001
From: Bruno Cardoso Lopes <bruno.cardoso at gmail.com>
Date: Fri, 2 May 2025 16:26:31 -0700
Subject: [PATCH 13/13] Use FailureOr for getOptDoubleValue

---
 mlir/lib/Target/LLVMIR/ModuleImport.cpp | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 4a0b4c29f9dc4..8ca5576900772 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -740,18 +740,17 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
     return val;
   };
 
-  auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
-                               FloatAttr &attr) -> LogicalResult {
+  auto getOptDoubleValue = [&](const llvm::MDOperand &md,
+                               StringRef matchKey) -> FailureOr<FloatAttr> {
     auto *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md,
                                                  matchKey, /*optional=*/true);
     if (!valMD)
-      return success();
+      return FloatAttr{};
     if (auto *cstFP = dyn_cast<llvm::ConstantFP>(valMD->getValue())) {
       if (checkOptionalPosition(md, matchKey).failed())
         return failure();
-      attr = FloatAttr::get(Float64Type::get(mlirModule.getContext()),
+      return FloatAttr::get(Float64Type::get(mlirModule.getContext()),
                             cstFP->getValueAPF());
-      return success();
     }
     emitWarning(mlirModule.getLoc())
         << "expected double metadata value for key '" << matchKey
@@ -807,12 +806,11 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
   if (isPartialProfile->has_value())
     summayIdx++;
 
-  FloatAttr partialProfileRatio;
-  if (getOptDoubleValue(mdTuple->getOperand(summayIdx), "PartialProfileRatio",
-                        partialProfileRatio)
-          .failed())
+  FailureOr<FloatAttr> partialProfileRatio =
+      getOptDoubleValue(mdTuple->getOperand(summayIdx), "PartialProfileRatio");
+  if (failed(partialProfileRatio))
     return nullptr;
-  if (partialProfileRatio)
+  if (*partialProfileRatio)
     summayIdx++;
 
   // Handle detailed summary.
@@ -826,7 +824,7 @@ convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
   return ModuleFlagProfileSummaryAttr::get(
       mlirModule->getContext(), *format, *totalCount, *maxCount,
       *maxInternalCount, *maxFunctionCount, *numCounts, *numFunctions,
-      *isPartialProfile, partialProfileRatio, *detailed);
+      *isPartialProfile, *partialProfileRatio, *detailed);
 }
 
 /// Invoke specific handlers for each known module flag value, returns nullptr



More information about the Mlir-commits mailing list