[Mlir-commits] [mlir] [MLIR][LLVM] Add ProfileSummary module flag support (PR #138070)
Tobias Gysi
llvmlistbot at llvm.org
Thu May 1 23:33:00 PDT 2025
================
@@ -554,13 +554,277 @@ static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
return ArrayAttr::get(mlirModule->getContext(), cgProfile);
}
+/// Returns `md` as an MDTuple if it is a tuple with exactly two operands
+/// (used by the callers below as a key/value pair). Otherwise emits a warning
+/// at the location of `mlirModule` and returns nullptr, so callers may access
+/// operands 0 and 1 of any non-null result without further checks.
+static llvm::MDTuple *getTwoElementMDTuple(ModuleOp mlirModule,
+                                           const llvm::Module *llvmModule,
+                                           const llvm::MDOperand &md) {
+  auto *tupleEntry = dyn_cast_or_null<llvm::MDTuple>(md);
+  if (!tupleEntry || tupleEntry->getNumOperands() != 2) {
+    emitWarning(mlirModule.getLoc())
+        << "expected 2-element tuple metadata: " << diagMD(md, llvmModule);
+    // Returning a tuple with the wrong arity would let callers index past the
+    // end of its operand list.
+    return nullptr;
+  }
+  return tupleEntry;
+}
+
+/// Extracts the constant value of a `!{!"<matchKey>", <constant>}` key/value
+/// tuple. Returns the ConstantAsMetadata operand when `md` is accepted by
+/// getTwoElementMDTuple, its first operand is the MDString `matchKey`, and its
+/// second operand is a constant. Returns nullptr otherwise.
+///
+/// When `optional` is true, a key mismatch or a non-constant value is not
+/// diagnosed. This lets callers probe for keys that may be absent. A rejected
+/// tuple is always diagnosed by getTwoElementMDTuple.
+static llvm::ConstantAsMetadata *getConstantMDFromKeyValueTuple(
+    ModuleOp mlirModule, const llvm::Module *llvmModule,
+    const llvm::MDOperand &md, StringRef matchKey, bool optional = false) {
+  llvm::MDTuple *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, md);
+  if (!tupleEntry)
+    return nullptr;
+
+  // Tuple operands may be null (e.g. `!{null, ...}`), so use the *_or_null
+  // casts to avoid asserting on malformed input.
+  auto *keyMD = dyn_cast_or_null<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != matchKey) {
+    if (!optional)
+      emitWarning(mlirModule.getLoc())
+          << "expected '" << matchKey << "' key, but found: "
+          << diagMD(tupleEntry->getOperand(0), llvmModule);
+    return nullptr;
+  }
+
+  auto *valMD =
+      dyn_cast_or_null<llvm::ConstantAsMetadata>(tupleEntry->getOperand(1));
+  // Without this, a required key with a non-constant value would make the
+  // caller fail without any diagnostic.
+  if (!valMD && !optional)
+    emitWarning(mlirModule.getLoc())
+        << "expected constant metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+  return valMD;
+}
+
+/// Stores the integer value of a `!{!"<matchKey>", iN <value>}` key/value tuple
+/// into `val` and returns true. Returns false and leaves `val` untouched on
+/// failure.
+///
+/// Warnings are emitted here when the value is not an integer or does not fit
+/// in 64 bits. Tuple-shape and key problems are diagnosed by
+/// getConstantMDFromKeyValueTuple.
+static bool convertInt64FromKeyValueTuple(ModuleOp mlirModule,
+                                          const llvm::Module *llvmModule,
+                                          const llvm::MDOperand &md,
+                                          StringRef matchKey, uint64_t &val) {
+  auto *valMD =
+      getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey);
+  if (!valMD)
+    return false;
+
+  // ConstantInt::getZExtValue asserts when the value needs more than 64 bits,
+  // so reject such wider-than-i64 constants instead of crashing the importer.
+  auto *cstInt = dyn_cast<llvm::ConstantInt>(valMD->getValue());
+  if (!cstInt || cstInt->getValue().getActiveBits() > 64) {
+    emitWarning(mlirModule.getLoc())
+        << "expected integer metadata value for key '" << matchKey
+        << "': " << diagMD(md, llvmModule);
+    return false;
+  }
+
+  val = cstInt->getZExtValue();
+  return true;
+}
+
+/// Converts the `!{!"ProfileFormat", !"<kind>"}` entry of a ProfileSummary
+/// into a ProfileSummaryFormatKind. Returns std::nullopt in any of these cases:
+///  - getTwoElementMDTuple rejects the entry;
+///  - the key is not 'ProfileFormat';
+///  - the value is not a string;
+///  - the string does not name a known format kind.
+/// A warning is emitted in each case.
+static std::optional<ProfileSummaryFormatKind>
+convertProfileSummaryFormat(ModuleOp mlirModule, const llvm::Module *llvmModule,
+                            const llvm::MDOperand &formatMD) {
+  auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, formatMD);
+  if (!tupleEntry)
+    return std::nullopt;
+
+  // Tuple operands may be null, so use the *_or_null casts.
+  auto *keyMD = dyn_cast_or_null<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != "ProfileFormat") {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'ProfileFormat' key: "
+        << diagMD(tupleEntry->getOperand(0), llvmModule);
+    return std::nullopt;
+  }
+
+  // The value must be validated before use. A non-string (or null) operand
+  // would otherwise be dereferenced when reading its string.
+  auto *valMD = dyn_cast_or_null<llvm::MDString>(tupleEntry->getOperand(1));
+  if (!valMD) {
+    emitWarning(mlirModule.getLoc())
+        << "expected string value for 'ProfileFormat' key: "
+        << diagMD(tupleEntry->getOperand(1), llvmModule);
+    return std::nullopt;
+  }
+
+  std::optional<ProfileSummaryFormatKind> fmtKind =
+      symbolizeProfileSummaryFormatKind(valMD->getString());
+  if (!fmtKind) {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'SampleProfile', 'InstrProf' or 'CSInstrProf' values, "
+           "but found: "
+        << diagMD(valMD, llvmModule);
+    return std::nullopt;
+  }
+
+  return fmtKind;
+}
+
+/// Converts the `!{!"DetailedSummary", !{!{iN, iN, iN}, ...}}` entry of a
+/// ProfileSummary into ModuleFlagProfileSummaryDetailedAttr values. Each
+/// converted entry is appended to `detailedSummary`, in order.
+///
+/// Returns false after emitting a warning in any of these cases:
+///  - getTwoElementMDTuple rejects the entry;
+///  - the key is not 'DetailedSummary';
+///  - the value is not a tuple;
+///  - an entry is not a tuple of three integer constants that fit in 64 bits.
+/// On failure, entries converted before the bad one remain appended.
+static bool convertProfileSummaryDetailed(
+    ModuleOp mlirModule, const llvm::Module *llvmModule,
+    const llvm::MDOperand &summaryMD,
+    SmallVectorImpl<ModuleFlagProfileSummaryDetailedAttr> &detailedSummary) {
+  auto *tupleEntry = getTwoElementMDTuple(mlirModule, llvmModule, summaryMD);
+  if (!tupleEntry)
+    return false;
+
+  // Tuple operands may be null, so use the *_or_null casts throughout.
+  auto *keyMD = dyn_cast_or_null<llvm::MDString>(tupleEntry->getOperand(0));
+  if (!keyMD || keyMD->getString() != "DetailedSummary") {
+    emitWarning(mlirModule.getLoc())
+        << "expected 'DetailedSummary' key: "
+        << diagMD(tupleEntry->getOperand(0), llvmModule);
+    return false;
+  }
+
+  auto *entriesMD = dyn_cast_or_null<llvm::MDTuple>(tupleEntry->getOperand(1));
+  if (!entriesMD) {
+    emitWarning(mlirModule.getLoc())
+        << "expected tuple value for 'DetailedSummary' key: "
+        << diagMD(tupleEntry->getOperand(1), llvmModule);
+    return false;
+  }
+
+  // Returns operand `idx` of `entryMD` as a ConstantInt. Returns nullptr if
+  // the operand is null, not a constant, or not an integer. Also returns
+  // nullptr if it is wider than 64 bits, since ConstantInt::getZExtValue
+  // would assert on such a value.
+  auto getIntOperand = [](llvm::MDTuple *entryMD,
+                          unsigned idx) -> llvm::ConstantInt * {
+    auto *cstMD =
+        dyn_cast_or_null<llvm::ConstantAsMetadata>(entryMD->getOperand(idx));
+    if (!cstMD)
+      return nullptr;
+    auto *cstInt = dyn_cast<llvm::ConstantInt>(cstMD->getValue());
+    if (!cstInt || cstInt->getValue().getActiveBits() > 64)
+      return nullptr;
+    return cstInt;
+  };
+
+  for (const llvm::MDOperand &entry : entriesMD->operands()) {
+    auto *entryMD = dyn_cast_or_null<llvm::MDTuple>(entry);
+    if (!entryMD || entryMD->getNumOperands() != 3) {
+      emitWarning(mlirModule.getLoc())
+          << "'DetailedSummary' entry expects 3 operands: "
+          << diagMD(entry, llvmModule);
+      return false;
+    }
+
+    llvm::ConstantInt *op0 = getIntOperand(entryMD, 0);
+    llvm::ConstantInt *op1 = getIntOperand(entryMD, 1);
+    llvm::ConstantInt *op2 = getIntOperand(entryMD, 2);
+    // Checking ConstantInt (not just ConstantAsMetadata) keeps this diagnostic
+    // accurate. It also prevents a later cast<> from asserting on non-integer
+    // constants.
+    if (!op0 || !op1 || !op2) {
+      emitWarning(mlirModule.getLoc())
+          << "expected only integer entries in 'DetailedSummary': "
+          << diagMD(entry, llvmModule);
+      return false;
+    }
+
+    detailedSummary.push_back(ModuleFlagProfileSummaryDetailedAttr::get(
+        mlirModule->getContext(), op0->getZExtValue(), op1->getZExtValue(),
+        op2->getZExtValue()));
+  }
+  return true;
+}
+
+static Attribute
+convertProfileSummaryModuleFlagValue(ModuleOp mlirModule,
+ const llvm::Module *llvmModule,
+ llvm::MDTuple *mdTuple) {
+ unsigned profileNumEntries = mdTuple->getNumOperands();
+ if (profileNumEntries < 8) {
+ emitWarning(mlirModule.getLoc())
+ << "expected at 8 entries in 'ProfileSummary': "
+ << diagMD(mdTuple, llvmModule);
+ return nullptr;
+ }
+
+ unsigned summayIdx = 0;
+ auto checkOptionalPosition = [&](const llvm::MDOperand &md,
+ StringRef matchKey) -> LogicalResult {
+ // Make sure we won't step over the bound of the array of summary entries.
+ // Since (non-optional) DetailedSummary always comes last, the next entry in
+ // the tuple operand array must exist.
+ if (summayIdx + 1 >= profileNumEntries) {
+ emitWarning(mlirModule.getLoc())
+ << "the last summary entry is '" << matchKey
+ << "', expected 'DetailedSummary': " << diagMD(md, llvmModule);
+ return failure();
+ }
+
+ return success();
+ };
+
+ auto getOptIntValue = [&](const llvm::MDOperand &md, StringRef matchKey,
+ std::optional<uint64_t> &val) -> LogicalResult {
+ if (!getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
+ /*optional=*/true))
+ return success();
+ if (checkOptionalPosition(md, matchKey).failed())
+ return failure();
+ uint64_t tmpVal = 0;
+ if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule, md, matchKey,
+ tmpVal))
+ return failure();
+ val = tmpVal;
+ return success();
+ };
+
+ auto getOptDoubleValue = [&](const llvm::MDOperand &md, StringRef matchKey,
+ FloatAttr &attr) -> LogicalResult {
+ auto *valMD = getConstantMDFromKeyValueTuple(mlirModule, llvmModule, md,
+ matchKey, /*optional=*/true);
+ if (!valMD)
+ return success();
+ if (auto *cstFP = dyn_cast<llvm::ConstantFP>(valMD->getValue())) {
+ if (checkOptionalPosition(md, matchKey).failed())
+ return failure();
+ attr = FloatAttr::get(Float64Type::get(mlirModule.getContext()),
+ cstFP->getValueAPF());
+ return success();
+ }
+ emitWarning(mlirModule.getLoc())
+ << "expected double metadata value for key '" << matchKey
+ << "': " << diagMD(md, llvmModule);
+ return failure();
+ };
+
+ // Build ModuleFlagProfileSummaryAttr by sequentially fetching elements in
+ // a fixed order: format, total count, etc.
+ SmallVector<Attribute> profileSummary;
+ std::optional<ProfileSummaryFormatKind> format = convertProfileSummaryFormat(
+ mlirModule, llvmModule, mdTuple->getOperand(summayIdx++));
+ if (!format.has_value())
+ return nullptr;
+
+ uint64_t totalCount = 0, maxCount = 0, maxInternalCount = 0,
+ maxFunctionCount = 0, numCounts = 0, numFunctions = 0;
+ if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+ mdTuple->getOperand(summayIdx++),
+ "TotalCount", totalCount))
+ return nullptr;
+ if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+ mdTuple->getOperand(summayIdx++),
+ "MaxCount", maxCount))
+ return nullptr;
+ if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+ mdTuple->getOperand(summayIdx++),
+ "MaxInternalCount", maxInternalCount))
+ return nullptr;
+ if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+ mdTuple->getOperand(summayIdx++),
+ "MaxFunctionCount", maxFunctionCount))
+ return nullptr;
+ if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+ mdTuple->getOperand(summayIdx++),
+ "NumCounts", numCounts))
+ return nullptr;
+ if (!convertInt64FromKeyValueTuple(mlirModule, llvmModule,
+ mdTuple->getOperand(summayIdx++),
+ "NumFunctions", numFunctions))
+ return nullptr;
+
+ // Handle optional keys.
+ std::optional<uint64_t> isPartialProfile;
+ if (getOptIntValue(mdTuple->getOperand(summayIdx), "IsPartialProfile",
+ isPartialProfile)
+ .failed())
+ return nullptr;
+ if (isPartialProfile.has_value())
+ summayIdx++;
+
+ FloatAttr partialProfileRatio;
+ if (getOptDoubleValue(mdTuple->getOperand(summayIdx), "PartialProfileRatio",
+ partialProfileRatio)
+ .failed())
+ return nullptr;
+ if (partialProfileRatio)
+ summayIdx++;
+
+ // Handle detailed summary.
+ SmallVector<ModuleFlagProfileSummaryDetailedAttr> detailed;
+ if (!convertProfileSummaryDetailed(mlirModule, llvmModule,
+ mdTuple->getOperand(summayIdx), detailed))
+ return nullptr;
+
+ // Build the final profile summary attribute.
+ return ModuleFlagProfileSummaryAttr::get(
+ mlirModule->getContext(), *format, totalCount, maxCount, maxInternalCount,
+ maxFunctionCount, numCounts, numFunctions, isPartialProfile,
+ partialProfileRatio ? partialProfileRatio : nullptr, detailed);
----------------
gysit wrote:
```suggestion
partialProfileRatio, detailed);
```
nit: The ternary is redundant. If `partialProfileRatio` is null, passing it directly already passes nullptr.
https://github.com/llvm/llvm-project/pull/138070
More information about the Mlir-commits mailing list