[Mlir-commits] [mlir] [MLIR][LLVM] Add ProfileSummary module flag support (PR #138070)

Tobias Gysi llvmlistbot at llvm.org
Thu May 1 23:33:00 PDT 2025


================
@@ -554,13 +554,277 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
   return ArrayAttr::get(mlirModule->getContext(), cgProfile);
 }
 
+/// Returns `md` as an MDTuple if it is a tuple with exactly two operands.
+/// Otherwise emits a warning at the module location and returns nullptr, so a
+/// non-null result is guaranteed to have valid operands 0 and 1.
+static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule,
+                                           const llvm::Module *llvmModule,
+                                           const llvm::MDOperand &md) {
+  auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
+  if (!tupleEntry || tupleEntry->getNumOperands() != 2) {
+    emitWarning(mlirModule.getLoc())
+        << "expected 2-element tuple metadata: " << diagMD(md, llvmModule);
+    // Callers index operands 0 and 1 unconditionally; handing back a tuple of
+    // the wrong arity would let them read past its end.
+    return nullptr;
+  }
+  return tupleEntry;
+}
+
+/// Interprets `md` as a two-element `{key, value}` tuple whose key is the
+/// MDString `matchKey`, and returns the value operand as ConstantAsMetadata.
+/// Returns nullptr if the tuple is malformed, the key does not match, or the
+/// value is not a constant. A key mismatch is diagnosed only when `optional`
+/// is false, so callers can probe for entries that may be absent. Once the
+/// key matches, a non-constant value is always diagnosed.
+static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
+    ModuleOp mlirModule, const llvm::Module *llvmModule,
+    const llvm::MDOperand &md, StringRef matchKey, bool optional = false) {
+  llvm::MDTuple *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, md);
+  if (!tupleEntry)
+    return nullptr;
+  // Tuple operands may be null; plain dyn_cast asserts on a null input.
+  auto *keyMD = dyn_cast_or_null<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != matchKey) {
+    if (!optional)
+      emitWarning(mlirModule.getLoc())
+          << "expected '" << matchKey << "' key, but found: "
+          << diagMD(tupleEntry->getOperand(0), llvmModule);
+    return nullptr;
+  }
+
+  auto *valMD =
+      dyn_cast_or_null<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
+  // Without this diagnostic, a matching key with a non-constant value made
+  // callers fail silently.
+  if (!valMD)
+    emitWarning(mlirModule.getLoc())
+        << "expected constant metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+  return valMD;
+}
+
+/// Reads the integer value of the `{matchKey, value}` tuple `md` into `val`,
+/// zero-extended to 64 bits. Returns false (leaving `val` untouched) when the
+/// tuple or key is invalid, or when the value is not a ConstantInt. The last
+/// case is diagnosed here; the others are diagnosed by the lookup helper.
+static bool convertInt64FromKeyValueTuple(ModuleOp mlirModule,
+                                          const llvm::Module *llvmModule,
+                                          const llvm::MDOperand &md,
+                                          StringRef matchKey, uint64_t &val) {
+  llvm::ConstantAsMetadata *constMD =
+      getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey);
+  if (constMD == nullptr)
+    return false;
+
+  auto *intConst = dyn_cast<llvm::ConstantInt>(constMD->getValue());
+  if (intConst == nullptr) {
+    emitWarning(mlirModule.getLoc())
+        << "expected integer metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+    return false;
+  }
+
+  val = intConst->getZExtValue();
+  return true;
+}
+
+/// Converts the `{"ProfileFormat", <string>}` entry of a ProfileSummary into
+/// a ProfileSummaryFormatKind. Emits a warning and returns std::nullopt when
+/// any of the following holds:
+///   - the entry is not a 2-element tuple;
+///   - the key is not 'ProfileFormat';
+///   - the value is not an MDString;
+///   - the string does not name a known format kind.
+static std::optional<ProfileSummaryFormatKind>
+convertProfileSummaryFormat(ModuleOp mlirModule, const llvm::Module *llvmModule,
+                            const llvm::MDOperand &formatMD) {
+  auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, formatMD);
+  if (!tupleEntry)
+    return std::nullopt;
+
+  auto *keyMD = dyn_cast_or_null<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != "ProfileFormat") {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'ProfileFormat' key: "
+        << diagMD(tupleEntry->getOperand(0), llvmModule);
+    return std::nullopt;
+  }
+
+  // The value was previously dereferenced without checking the cast, which
+  // crashed on a non-string (or null) value operand.
+  auto *valMD = dyn_cast_or_null<llvm::MDString>(tupleEntry->getOperand(1));
+  if (!valMD) {
+    emitWarning(mlirModule.getLoc())
+        << "expected string value for 'ProfileFormat' key: "
+        << diagMD(tupleEntry, llvmModule);
+    return std::nullopt;
+  }
+
+  std::optional<ProfileSummaryFormatKind> fmtKind =
+      symbolizeProfileSummaryFormatKind(valMD->getString());
+  if (!fmtKind) {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
+           "but found: "
+        << diagMD(valMD, llvmModule);
+    return std::nullopt;
+  }
+
+  return fmtKind;
+}
+
+/// Converts the `{"DetailedSummary", !{...}}` entry of a ProfileSummary.
+/// Each element of the inner tuple must be a 3-operand tuple of integer
+/// constants. Each one is appended to `detailedSummary` as a
+/// ModuleFlagProfileSummaryDetailedAttr. On the first malformed component,
+/// emits a warning and returns false. Entries converted before the failure
+/// remain in `detailedSummary`.
+static bool convertProfileSummaryDetailed(
+    ModuleOp mlirModule, const llvm::Module *llvmModule,
+    const llvm::MDOperand &summaryMD,
+    SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr> &detailedSummary) {
+  auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, summaryMD);
+  if (!tupleEntry)
+    return false;
+
+  auto *keyMD = dyn_cast_or_null<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != "DetailedSummary") {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'DetailedSummary' key: "
+        << diagMD(tupleEntry->getOperand(0), llvmModule);
+    return false;
+  }
+
+  auto *entriesMD = dyn_cast_or_null<llvm::MDTuple>(tupleEntry->getOperand(1));
+  if (!entriesMD) {
+    emitWarning(mlirModule.getLoc())
+        << "expected tuple value for 'DetailedSummary' key: "
+        << diagMD(tupleEntry->getOperand(1), llvmModule);
+    return false;
+  }
+
+  // Returns the operand as a ConstantInt, or nullptr if it is null, not a
+  // constant, or a constant of non-integer type.
+  auto getIntOperand = [](const llvm::MDOperand &op) -> llvm::ConstantInt * {
+    auto *constMD = dyn_cast_or_null<llvm::ConstantAsMetadata>(op);
+    if (!constMD)
+      return nullptr;
+    return dyn_cast<llvm::ConstantInt>(constMD->getValue());
+  };
+
+  for (const llvm::MDOperand &entry : entriesMD->operands()) {
+    auto *entryMD = dyn_cast_or_null<llvm::MDTuple>(entry);
+    if (!entryMD || entryMD->getNumOperands() != 3) {
+      emitWarning(mlirModule.getLoc())
+          << "'DetailedSummary' entry expects 3 operands: "
+          << diagMD(entry, llvmModule);
+      return false;
+    }
+    // Check for ConstantInt up front: the previous cast<ConstantInt> asserted
+    // on non-integer constants (e.g. floats) that passed the
+    // ConstantAsMetadata check.
+    llvm::ConstantInt *op0 = getIntOperand(entryMD->getOperand(0));
+    llvm::ConstantInt *op1 = getIntOperand(entryMD->getOperand(1));
+    llvm::ConstantInt *op2 = getIntOperand(entryMD->getOperand(2));
+    if (!op0 || !op1 || !op2) {
+      emitWarning(mlirModule.getLoc())
+          << "expected only integer entries in 'DetailedSummary': "
+          << diagMD(entry, llvmModule);
+      return false;
+    }
+
+    auto detailedSummaryEntry = ModuleFlagProfileSummaryDetailedAttr::get(
+        mlirModule->getContext(), op0->getZExtValue(), op1->getZExtValue(),
+        op2->getZExtValue());
+    detailedSummary.push_back(detailedSummaryEntry);
+  }
+  return true;
+}
+
+static Attribute
+convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
+                                     const llvm::Module *llvmModule,
+                                     llvm::MDTuple *mdTuple) {
+  unsigned profileNumEntries = mdTuple->getNumOperands();
+  if (profileNumEntries < 8) {
+    emitWarning(mlirModule.getLoc())
+        << "expected at 8 entries in 'ProfileSummary': "
+        << diagMD(mdTuple, llvmModule);
+    return nullptr;
+  }
+
+  unsigned summayIdx = 0;
+  auto checkOptionalPosition = [&](const llvm::MDOperand &md,
+                                   StringRef matchKey) -> LogicalResult {
+    // Make sure we won't step over the bound of the array of summary entries.
+    // Since (non-optional) DetailedSummary always comes last, the next entry in
+    // the tuple operand array must exist.
+    if (summayIdx + 1 >= profileNumEntries) {
+      emitWarning(mlirModule.getLoc())
+          << "the last summary entry is '" << matchKey
+          << "', expected 'DetailedSummary': " << diagMD(md, llvmModule);
+      return failure();
+    }
+
+    return success();
+  };
+
+  auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
+                            std::optional<uint64_t> &val) -> LogicalResult {
+    if (!getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
+                                        /*optional=*/true))
+      return success();
+    if (checkOptionalPosition(md, matchKey).failed())
+      return failure();
+    uint64_t tmpVal = 0;
+    if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
+                                       tmpVal))
+      return failure();
+    val = tmpVal;
+    return success();
+  };
+
+  auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
+                               FloatAttr &attr) -> LogicalResult {
+    auto *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md,
+                                                 matchKey, /*optional=*/true);
+    if (!valMD)
+      return success();
+    if (auto *cstFP = dyn_cast<llvm::ConstantFP>(valMD->getValue())) {
+      if (checkOptionalPosition(md, matchKey).failed())
+        return failure();
+      attr = FloatAttr::get(Float64Type::get(mlirModule.getContext()),
+                            cstFP->getValueAPF());
+      return success();
+    }
+    emitWarning(mlirModule.getLoc())
+        << "expected double metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+    return failure();
+  };
+
+  // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
+  // a fixed order: format, total count, etc.
+  SmallVector<Attribute> profileSummary;
+  std::optional<ProfileSummaryFormatKind> format = convertProfileSummaryFormat(
+      mlirModule, llvmModule, mdTuple->getOperand(summayIdx++));
+  if (!format.has_value())
+    return nullptr;
+
+  uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
+           maxFunctionCount = 0, numCounts = 0, numFunctions = 0;
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "TotalCount", totalCount))
+    return nullptr;
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "MaxCount", maxCount))
+    return nullptr;
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "MaxInternalCount", maxInternalCount))
+    return nullptr;
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "MaxFunctionCount", maxFunctionCount))
+    return nullptr;
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "NumCounts", numCounts))
+    return nullptr;
+  if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx++),
+                                     "NumFunctions", numFunctions))
+    return nullptr;
+
+  // Handle optional keys.
+  std::optional<uint64_t> isPartialProfile;
+  if (getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile",
+                     isPartialProfile)
+          .failed())
+    return nullptr;
+  if (isPartialProfile.has_value())
+    summayIdx++;
+
+  FloatAttr partialProfileRatio;
+  if (getOptDoubleValue(mdTuple->getOperand(summayIdx), "PartialProfileRatio",
+                        partialProfileRatio)
+          .failed())
+    return nullptr;
+  if (partialProfileRatio)
+    summayIdx++;
+
+  // Handle detailed summary.
+  SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailed;
+  if (!convertProfileSummaryDetailed(mlirModule, llvmModule,
+                                     mdTuple->getOperand(summayIdx), detailed))
+    return nullptr;
+
+  // Build the final profile summary attribute.
+  return ModuleFlagProfileSummaryAttr::get(
+      mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount,
+      maxFunctionCount, numCounts, numFunctions, isPartialProfile,
+      partialProfileRatio ? partialProfileRatio : nullptr, detailed);
----------------
gysit wrote:

```suggestion
      partialProfileRatio, detailed);
```
nit: The ternary is redundant. If `partialProfileRatio` is null, passing it directly already passes nullptr.

https://github.com/llvm/llvm-project/pull/138070


More information about the Mlir-commits mailing list