[Mlir-commits] [mlir] [mlir][LLVM] Fix import of branch weights with "expected" field (PR #169776)

Vadim Curcă llvmlistbot at llvm.org
Thu Nov 27 00:17:42 PST 2025


https://github.com/VadimCurca updated https://github.com/llvm/llvm-project/pull/169776

>From 3478db19c1ad88ed15809ee7978bf2faba85c341 Mon Sep 17 00:00:00 2001
From: VadimCurca <vadim.curca14 at gmail.com>
Date: Thu, 27 Nov 2025 09:02:40 +0100
Subject: [PATCH] [mlir][llvm] Fix import of branch weights with "expected"
 field

This commit fixes the import of `branch_weights` metadata from LLVM IR
to the LLVM dialect. Previously, `branch_weights` metadata containing
the `!"expected"` field was rejected because the importer expected
integer weights at operand 1, but found a string.
---
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        | 23 +++++++++---
 .../LLVMIR/Import/metadata-profiling.ll       | 36 +++++++++++++++++++
 2 files changed, 54 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 44732d5466f6d..81c9da1d98c40 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -113,7 +113,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
     return failure();
 
   // Handle function entry count metadata.
-  if (name->getString() == "function_entry_count") {
+  if (name->getString() == llvm::MDProfLabels::FunctionEntryCount) {
 
     // TODO support function entry count metadata with GUID fields.
     if (node->getNumOperands() != 2)
@@ -131,15 +131,28 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
            << "expected function_entry_count to be attached to a function";
   }
 
-  if (name->getString() != "branch_weights")
+  if (name->getString() != llvm::MDProfLabels::BranchWeights)
     return failure();
+  // The branch_weights metadata must have at least 2 operands.
+  if (node->getNumOperands() < 2)
+    return failure();
+
+  ArrayRef<llvm::MDOperand> branchWeightOperands =
+      node->operands().drop_front();
+  if (auto *mdString = dyn_cast<llvm::MDString>(node->getOperand(1))) {
+    if (mdString->getString() != llvm::MDProfLabels::ExpectedBranchWeights)
+      return failure();
+    // The MLIR WeightedBranchOpInterface does not support the
+    // ExpectedBranchWeights field, so it is dropped.
+    branchWeightOperands = branchWeightOperands.drop_front();
+  }
 
   // Handle branch weights metadata.
   SmallVector<int32_t> branchWeights;
-  branchWeights.reserve(node->getNumOperands() - 1);
-  for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) {
+  branchWeights.reserve(branchWeightOperands.size());
+  for (const llvm::MDOperand &operand : branchWeightOperands) {
     llvm::ConstantInt *branchWeight =
-        llvm::mdconst::dyn_extract<llvm::ConstantInt>(node->getOperand(i));
+        llvm::mdconst::dyn_extract<llvm::ConstantInt>(operand);
     if (!branchWeight)
       return failure();
     branchWeights.push_back(branchWeight->getZExtValue());
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
index c623df0b605b2..328062545ed63 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
@@ -16,6 +16,22 @@ bb2:
 
 ; // -----
 
+; CHECK-LABEL: @cond_br_expected
+define i64 @cond_br_expected(i1 %arg1, i64 %arg2) {
+entry:
+  ; CHECK: llvm.cond_br
+  ; CHECK-SAME: weights([1, 2000])
+  br i1 %arg1, label %bb1, label %bb2, !prof !0
+bb1:
+  ret i64 %arg2
+bb2:
+  ret i64 %arg2
+}
+
+!0 = !{!"branch_weights", !"expected", i32 1, i32 2000}
+
+; // -----
+
 ; CHECK-LABEL: @simple_switch(
 define i32 @simple_switch(i32 %arg1) {
   ; CHECK: llvm.switch
@@ -36,6 +52,26 @@ bbd:
 
 ; // -----
 
+; CHECK-LABEL: @simple_switch_expected(
+define i32 @simple_switch_expected(i32 %arg1) {
+  ; CHECK: llvm.switch
+  ; CHECK: {branch_weights = array<i32: 1, 1, 2000>}
+  switch i32 %arg1, label %bbd [
+    i32 0, label %bb1
+    i32 9, label %bb2
+  ], !prof !0
+bb1:
+  ret i32 %arg1
+bb2:
+  ret i32 %arg1
+bbd:
+  ret i32 %arg1
+}
+
+!0 = !{!"branch_weights", !"expected", i32 1, i32 1, i32 2000}
+
+; // -----
+
 ; Verify that a single weight attached to a call is not translated.
 ; The MLIR WeightedBranchOpInterface does not support this case.
 



More information about the Mlir-commits mailing list