[flang-commits] [flang] [mlir] [mlir][flang] Added Weighted[Region]BranchOpInterface's. (PR #142079)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Wed Jun 11 20:16:53 PDT 2025


https://github.com/vzakhari updated https://github.com/llvm/llvm-project/pull/142079

>From ff8a77618df58774eeed5af123778212049475f4 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 29 May 2025 19:09:16 -0700
Subject: [PATCH 1/7] [mlir][flang] Added Weighted[Region]BranchOpInterface's.

The new interfaces provide getters and setters for the weight
information about the branches of BranchOpInterface and
RegionBranchOpInterface operations.

These interfaces are implemented the same way as the LLVM dialect's
BranchWeightOpInterface.

The plan is to produce this information in Flang, e.g. to mark code
that is most likely "cold" as such and let LLVM order basic blocks
accordingly. An example of such code is the copy loops generated for
array repacking: we can mark them as "cold", assuming that the copy
will not happen at run time. If the copy does happen, its cost is
probably high enough that the small overhead of jumping to the "cold"
code and fetching it does not matter.
---
 .../include/flang/Optimizer/Dialect/FIROps.td |  18 ++-
 flang/lib/Optimizer/Dialect/FIROps.cpp        |  21 +++-
 .../Transforms/ControlFlowConverter.cpp       |   4 +-
 flang/test/Fir/cfg-conversion-if.fir          |  46 ++++++++
 flang/test/Fir/fir-ops.fir                    |  16 +++
 flang/test/Fir/invalid.fir                    |  37 ++++++
 .../Dialect/ControlFlow/IR/ControlFlowOps.td  |  34 +++---
 .../mlir/Interfaces/ControlFlowInterfaces.h   |  20 ++++
 .../mlir/Interfaces/ControlFlowInterfaces.td  | 107 ++++++++++++++++++
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   |   6 +-
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp |  49 ++++++++
 .../Conversion/ControlFlowToLLVM/branch.mlir  |  14 +++
 mlir/test/Dialect/ControlFlow/invalid.mlir    |  36 ++++++
 mlir/test/Dialect/ControlFlow/ops.mlir        |  10 ++
 14 files changed, 396 insertions(+), 22 deletions(-)
 create mode 100644 flang/test/Fir/cfg-conversion-if.fir

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 90e05ce3d5ca6..27a6ca4ebdb4e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
   }];
 }
 
-def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
-    "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
-    NoRegionArguments]> {
+def fir_IfOp
+    : region_Op<
+          "if", [DeclareOpInterfaceMethods<
+                     RegionBranchOpInterface, ["getRegionInvocationBounds",
+                                               "getEntrySuccessorRegions"]>,
+                 RecursiveMemoryEffects, NoRegionArguments,
+                 WeightedRegionBranchOpInterface]> {
   let summary = "if-then-else conditional operation";
   let description = [{
     Used to conditionally execute operations. This operation is the FIR
@@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
     ```
   }];
 
-  let arguments = (ins I1:$condition);
+  let arguments = (ins I1:$condition,
+      OptionalAttr<DenseI32ArrayAttr>:$region_weights);
   let results = (outs Variadic<AnyType>:$results);
 
   let regions = (region
@@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
 
     void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
                            unsigned resultNum);
+
+    /// Returns the display name string for the region_weights attribute.
+    static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
+      return "weights";
+    }
   }];
 }
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 6181e1fad4240..ecfa2939e96a6 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
       parser.resolveOperand(cond, i1Type, result.operands))
     return mlir::failure();
 
+  if (mlir::succeeded(
+          parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
+    if (parser.parseLParen())
+      return mlir::failure();
+    mlir::DenseI32ArrayAttr weights;
+    if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
+      return mlir::failure();
+    if (weights)
+      result.addAttribute(getRegionWeightsAttrName(result.name), weights);
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
   if (parser.parseOptionalArrowTypeList(result.types))
     return mlir::failure();
 
@@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
 void fir::IfOp::print(mlir::OpAsmPrinter &p) {
   bool printBlockTerminators = false;
   p << ' ' << getCondition();
+  if (auto weights = getRegionWeightsAttr()) {
+    p << ' ' << getWeightsAttrAssemblyName() << '(';
+    p.printStrippedAttrOrType(weights);
+    p << ')';
+  }
   if (!getResults().empty()) {
     p << " -> (" << getResultTypes() << ')';
     printBlockTerminators = true;
@@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
     p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
                   printBlockTerminators);
   }
-  p.printOptionalAttrDict((*this)->getAttrs());
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elideAttrs=*/{getRegionWeightsAttrName()});
 }
 
 void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index 8a9e9b80134b8..5256ef8d53d85 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -212,9 +212,11 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
     }
 
     rewriter.setInsertionPointToEnd(condBlock);
-    rewriter.create<mlir::cf::CondBranchOp>(
+    auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
         loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
+    if (auto weights = ifOp.getRegionWeightsOrNull())
+      branchOp.setBranchWeights(weights);
     rewriter.replaceOp(ifOp, continueBlock->getArguments());
     return success();
   }
diff --git a/flang/test/Fir/cfg-conversion-if.fir b/flang/test/Fir/cfg-conversion-if.fir
new file mode 100644
index 0000000000000..1e30ee8e64f02
--- /dev/null
+++ b/flang/test/Fir/cfg-conversion-if.fir
@@ -0,0 +1,46 @@
+// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s
+
+func.func private @callee() -> none
+
+// CHECK-LABEL:   func.func @if_then(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) {
+// CHECK:           cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           %[[VAL_0:.*]] = fir.call @callee() : () -> none
+// CHECK:           cf.br ^bb2
+// CHECK:         ^bb2:
+// CHECK:           return
+// CHECK:         }
+func.func @if_then(%cond: i1) {
+  fir.if %cond weights([10, 90]) {
+    fir.call @callee() : () -> none
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @if_then_else(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) -> i32 {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           cf.br ^bb3(%[[VAL_0]] : i32)
+// CHECK:         ^bb2:
+// CHECK:           cf.br ^bb3(%[[VAL_1]] : i32)
+// CHECK:         ^bb3(%[[VAL_2:.*]]: i32):
+// CHECK:           cf.br ^bb4
+// CHECK:         ^bb4:
+// CHECK:           return %[[VAL_2]] : i32
+// CHECK:         }
+func.func @if_then_else(%cond: i1) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %result = fir.if %cond weights([90, 10]) -> i32 {
+    fir.result %c0 : i32
+  } else {
+    fir.result %c1 : i32
+  }
+  return %result : i32
+}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 9c444d2f4e0bc..3585bf9efca3e 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
   %6 = arith.addi %2, %5 : index
   return %6 : index
 }
+
+// CHECK-LABEL:   func.func @test_if_weights(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) {
+func.func @test_if_weights(%cond: i1) {
+// CHECK:           fir.if %[[ARG0]] weights([99, 1]) {
+// CHECK:           }
+  fir.if %cond weights([99, 1]) {
+  }
+// CHECK:           fir.if %[[ARG0]] weights([99, 1]) {
+// CHECK:           } else {
+// CHECK:           }
+  fir.if %cond weights ([99,1]) {
+  } else {
+  }
+  return
+}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index 45cae1f82cb8e..facb49a01d067 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -1393,3 +1393,40 @@ fir.local {type = local_init} @x.localizer : f32 init {
 ^bb0(%arg0: f32, %arg1: f32):
   fir.yield(%arg0 : f32)
 }
+
+// -----
+
+func.func @wrong_weights_number_in_if_then(%cond: i1) {
+// expected-error @below {{number of weights (1) does not match the number of regions (2)}}
+  fir.if %cond weights([50]) {
+  }
+  return
+}
+
+// -----
+
+func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
+// expected-error @below {{number of weights (3) does not match the number of regions (2)}}
+  fir.if %cond weights([50, 40, 10]) {
+  } else {
+  }
+  return
+}
+
+// -----
+
+func.func @negative_weight_in_if_then(%cond: i1) {
+// expected-error @below {{weight #0 must be non-negative}}
+  fir.if %cond weights([-1, 101]) {
+  }
+  return
+}
+
+// -----
+
+func.func @wrong_total_weight_in_if_then(%cond: i1) {
+// expected-error @below {{total weight 101 is not 100}}
+  fir.if %cond weights([1, 100]) {
+  }
+  return
+}
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index 48f12b46a57f1..79da81ba049dd 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
 // CondBranchOp
 //===----------------------------------------------------------------------===//
 
-def CondBranchOp : CF_Op<"cond_br",
-    [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
-     Pure, Terminator]> {
+def CondBranchOp
+    : CF_Op<"cond_br", [AttrSizedOperandSegments,
+                        DeclareOpInterfaceMethods<
+                            BranchOpInterface, ["getSuccessorForOperands"]>,
+                        WeightedBranchOpInterface, Pure, Terminator]> {
   let summary = "Conditional branch operation";
   let description = [{
     The `cf.cond_br` terminator operation represents a conditional branch on a
@@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
     ```
   }];
 
-  let arguments = (ins I1:$condition,
-                       Variadic<AnyType>:$trueDestOperands,
-                       Variadic<AnyType>:$falseDestOperands);
+  let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
+      Variadic<AnyType>:$falseDestOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
 
-  let builders = [
-    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
-      "ValueRange":$trueOperands, "Block *":$falseDest,
-      "ValueRange":$falseOperands), [{
-      build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
+  let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+                                "ValueRange":$trueOperands,
+                                "Block *":$falseDest,
+                                "ValueRange":$falseOperands),
+                            [{
+      build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
             falseDest);
     }]>,
-    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
-      "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
+                  OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+                                "Block *":$falseDest,
+                                CArg<"ValueRange", "{}">:$falseOperands),
+                            [{
       build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
             falseOperands);
     }]>];
@@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",
 
   let hasCanonicalizer = 1;
   let assemblyFormat = [{
-    $condition `,`
+    $condition (`weights` `(` $branch_weights^ `)` )? `,`
     $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
     $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
     attr-dict
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 7f6967f11444f..d63800c12d132 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
                                             const SuccessorOperands &operands);
 } // namespace detail
 
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Verify that the branch weights attached to an operation
+/// implementing WeightedBranchOpInterface are correct.
+LogicalResult verifyBranchWeights(Operation *op);
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Verify that the region weights attached to an operation
+/// implementing WeightedRegionBranchOpInterface are correct.
+LogicalResult verifyRegionBranchWeights(Operation *op);
+} // namespace detail
+
 //===----------------------------------------------------------------------===//
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 69bce78e946c8..7a47b686ac7d1 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -375,6 +375,113 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
+  let description = [{
+    This interface provides weight information for branching terminator
+    operations, i.e. terminator operations with successors.
+
+    This interface provides methods for getting/setting the non-negative
+    integer weight of each branch, in the range from 0 to 100. The sum of
+    the weights must be 100. The number of weights must match the number
+    of successors of the operation.
+
+    The weights specify the probability (as a percentage) of taking
+    a particular branch.
+
+    The default implementations of the methods expect the operation
+    to have an attribute of type DenseI32ArrayAttr named branch_weights.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [InterfaceMethod<
+                     /*desc=*/"Returns the branch weights attribute or nullptr",
+                     /*returnType=*/"::mlir::DenseI32ArrayAttr",
+                     /*methodName=*/"getBranchWeightsOrNull",
+                     /*args=*/(ins),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getBranchWeightsAttr();
+      }]>,
+                 InterfaceMethod<
+                     /*desc=*/"Sets the branch weights attribute",
+                     /*returnType=*/"void",
+                     /*methodName=*/"setBranchWeights",
+                     /*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        op.setBranchWeightsAttr(attr);
+      }]>,
+  ];
+
+  let verify = [{
+    return ::mlir::detail::verifyBranchWeights($_op);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+// TODO: the probabilities of entering a particular region seem
+// to correlate with the values returned by
+// RegionBranchOpInterface::getRegionInvocationBounds(), and we should
+// probably verify that the values are consistent. In that case, should
+// WeightedRegionBranchOpInterface extend RegionBranchOpInterface?
+def WeightedRegionBranchOpInterface
+    : OpInterface<"WeightedRegionBranchOpInterface"> {
+  let description = [{
+    This interface provides weight information for region operations
+    that exhibit branching behavior between held regions.
+
+    This interface provides methods for getting/setting the non-negative
+    integer weight of each branch, in the range from 0 to 100. The sum of
+    the weights must be 100. The number of weights must match the number
+    of regions held by the operation (including empty regions).
+
+    The weights specify the probability (as a percentage) of branching
+    to a particular region when first executing the operation.
+    For example, for loop-like operations with a single region
+    the weight specifies the probability of entering the loop.
+    In this case, the weight must be either 0 or 100.
+
+    The default implementations of the methods expect the operation
+    to have an attribute of type DenseI32ArrayAttr named region_weights.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [InterfaceMethod<
+                     /*desc=*/"Returns the region weights attribute or nullptr",
+                     /*returnType=*/"::mlir::DenseI32ArrayAttr",
+                     /*methodName=*/"getRegionWeightsOrNull",
+                     /*args=*/(ins),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRegionWeightsAttr();
+      }]>,
+                 InterfaceMethod<
+                     /*desc=*/"Sets the region weights attribute",
+                     /*returnType=*/"void",
+                     /*methodName=*/"setRegionWeights",
+                     /*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        op.setRegionWeightsAttr(attr);
+      }]>,
+  ];
+
+  let verify = [{
+    return ::mlir::detail::verifyRegionBranchWeights($_op);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd003bd5b5..12769e486a3c7 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -166,10 +166,14 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
                           TypeRange(adaptor.getFalseDestOperands()));
     if (failed(convertedFalseBlock))
       return failure();
-    Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+    auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
         op, adaptor.getCondition(), *convertedTrueBlock,
         adaptor.getTrueDestOperands(), *convertedFalseBlock,
         adaptor.getFalseDestOperands());
+    if (auto weights = op.getBranchWeightsOrNull()) {
+      newOp.setBranchWeights(weights);
+      op.removeBranchWeightsAttr();
+    }
     // TODO: We should not just forward all attributes like that. But there are
     // existing Flang tests that depend on this behavior.
     newOp->setAttrs(op->getAttrDictionary());
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2ae334b517a31..e587e8f1af178 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -80,6 +80,55 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult detail::verifyBranchWeights(Operation *op) {
+  auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
+  if (weights) {
+    if (weights.size() != op->getNumSuccessors())
+      return op->emitError() << "number of weights (" << weights.size()
+                             << ") does not match the number of successors ("
+                             << op->getNumSuccessors() << ")";
+    int32_t total = 0;
+    for (auto weight : llvm::enumerate(weights.asArrayRef())) {
+      if (weight.value() < 0)
+        return op->emitError()
+               << "weight #" << weight.index() << " must be non-negative";
+      total += weight.value();
+    }
+    if (total != 100)
+      return op->emitError() << "total weight " << total << " is not 100";
+  }
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
+  auto weights =
+      cast<WeightedRegionBranchOpInterface>(op).getRegionWeightsOrNull();
+  if (weights) {
+    if (weights.size() != op->getNumRegions())
+      return op->emitError() << "number of weights (" << weights.size()
+                             << ") does not match the number of regions ("
+                             << op->getNumRegions() << ")";
+    int32_t total = 0;
+    for (auto weight : llvm::enumerate(weights.asArrayRef())) {
+      if (weight.value() < 0)
+        return op->emitError()
+               << "weight #" << weight.index() << " must be non-negative";
+      total += weight.value();
+    }
+    if (total != 100)
+      return op->emitError() << "total weight " << total << " is not 100";
+  }
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
index 9a0f2b7714544..7c78211d59010 100644
--- a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
+++ b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
@@ -67,3 +67,17 @@ func.func @unreachable_block() {
 ^bb1(%arg0: index):
   cf.br ^bb1(%arg0 : index)
 }
+
+// -----
+
+// Test case for cf.cond_br with weights.
+
+// CHECK-LABEL:   func.func @cf_cond_br_with_weights(
+func.func @cf_cond_br_with_weights(%cond: i1, %a: index, %b: index) -> index {
+// CHECK:           llvm.cond_br %{{.*}} weights([90, 10]), ^bb1(%{{.*}} : i64), ^bb2(%{{.*}} : i64)
+  cf.cond_br %cond, ^bb1(%a : index), ^bb2(%b : index) {branch_weights = array<i32: 90, 10>}
+^bb1(%arg1: index):
+  return %arg1 : index
+^bb2(%arg2: index):
+  return %arg2 : index
+}
diff --git a/mlir/test/Dialect/ControlFlow/invalid.mlir b/mlir/test/Dialect/ControlFlow/invalid.mlir
index b51d8095c9974..6024c6d55ac64 100644
--- a/mlir/test/Dialect/ControlFlow/invalid.mlir
+++ b/mlir/test/Dialect/ControlFlow/invalid.mlir
@@ -67,3 +67,39 @@ func.func @switch_missing_default(%flag : i32, %caseOperand : i32) {
   ^bb3(%bb3arg : i32):
     return
 }
+
+// -----
+
+// CHECK-LABEL: func @wrong_weights_number
+func.func @wrong_weights_number(%cond: i1) {
+  // expected-error at +1 {{number of weights (1) does not match the number of successors (2)}}
+  cf.cond_br %cond weights([100]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_weight
+func.func @wrong_total_weight(%cond: i1) {
+  // expected-error at +1 {{weight #0 must be non-negative}}
+  cf.cond_br %cond weights([-1, 101]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @wrong_total_weight
+func.func @wrong_total_weight(%cond: i1) {
+  // expected-error at +1 {{total weight 101 is not 100}}
+  cf.cond_br %cond weights([100, 1]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
diff --git a/mlir/test/Dialect/ControlFlow/ops.mlir b/mlir/test/Dialect/ControlFlow/ops.mlir
index c9317c7613972..160534240e0fa 100644
--- a/mlir/test/Dialect/ControlFlow/ops.mlir
+++ b/mlir/test/Dialect/ControlFlow/ops.mlir
@@ -51,3 +51,13 @@ func.func @switch_result_number(%arg0: i32) {
   ^bb2:
     return
 }
+
+// CHECK-LABEL: func @cond_weights
+func.func @cond_weights(%cond: i1) {
+// CHECK: cf.cond_br %{{.*}} weights([60, 40]), ^{{.*}}, ^{{.*}}
+  cf.cond_br %cond weights([60, 40]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}

>From c67ab473d0cdbd0903a2433b1b0c2c3b65ad5fb4 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 2 Jun 2025 16:57:05 -0700
Subject: [PATCH 2/7] Replaced LLVM dialect's BranchWeightOpInterface with
 WeightedBranchOpInterface.

---
 flang/test/Fir/invalid.fir                    | 13 +----
 .../mlir/Dialect/LLVMIR/LLVMInterfaces.td     | 36 -------------
 mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td   | 47 ++++++++--------
 .../mlir/Interfaces/ControlFlowInterfaces.td  | 23 ++++----
 .../mlir/Target/LLVMIR/ModuleTranslation.h    |  2 +-
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 53 +++++++++----------
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        |  2 +-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  4 +-
 mlir/test/Dialect/ControlFlow/invalid.mlir    | 14 +----
 9 files changed, 68 insertions(+), 126 deletions(-)

diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index facb49a01d067..aca0ecc1abdc1 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -1397,7 +1397,7 @@ fir.local {type = local_init} @x.localizer : f32 init {
 // -----
 
 func.func @wrong_weights_number_in_if_then(%cond: i1) {
-// expected-error @below {{number of weights (1) does not match the number of regions (2)}}
+// expected-error @below {{expects number of region weights to match number of regions: 1 vs 2}}
   fir.if %cond weights([50]) {
   }
   return
@@ -1406,7 +1406,7 @@ func.func @wrong_weights_number_in_if_then(%cond: i1) {
 // -----
 
 func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
-// expected-error @below {{number of weights (3) does not match the number of regions (2)}}
+// expected-error @below {{expects number of region weights to match number of regions: 3 vs 2}}
   fir.if %cond weights([50, 40, 10]) {
   } else {
   }
@@ -1421,12 +1421,3 @@ func.func @negative_weight_in_if_then(%cond: i1) {
   }
   return
 }
-
-// -----
-
-func.func @wrong_total_weight_in_if_then(%cond: i1) {
-// expected-error @below {{total weight 101 is not 100}}
-  fir.if %cond weights([1, 100]) {
-  }
-  return
-}
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 2824f09dab6ce..138170f8c8762 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -168,42 +168,6 @@ def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
   ];
 }
 
-def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
-  let description = [{
-    An interface for operations that can carry branch weights metadata. It
-    provides setters and getters for the operation's branch weights attribute.
-    The default implementation of the interface methods expect the operation to
-    have an attribute of type DenseI32ArrayAttr named branch_weights.
-  }];
-
-  let cppNamespace = "::mlir::LLVM";
-
-  let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns the branch weights attribute or nullptr",
-      /*returnType=*/  "::mlir::DenseI32ArrayAttr",
-      /*methodName=*/  "getBranchWeightsOrNull",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        return op.getBranchWeightsAttr();
-      }]
-      >,
-    InterfaceMethod<
-      /*desc=*/        "Sets the branch weights attribute",
-      /*returnType=*/  "void",
-      /*methodName=*/  "setBranchWeights",
-      /*args=*/        (ins "::mlir::DenseI32ArrayAttr":$attr),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        op.setBranchWeightsAttr(attr);
-      }]
-      >
-  ];
-}
-
 def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
   let description = [{
     An interface for memory operations that can carry access groups metadata.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index eda1d544cd81c..665f19b17eed5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -660,12 +660,12 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
                                  LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
 
 // Call-related operations.
-def LLVM_InvokeOp : LLVM_Op<"invoke", [
-                      AttrSizedOperandSegments,
-                      DeclareOpInterfaceMethods<BranchOpInterface>,
-                      DeclareOpInterfaceMethods<CallOpInterface>,
-                      DeclareOpInterfaceMethods<BranchWeightOpInterface>,
-                      Terminator]> {
+def LLVM_InvokeOp
+    : LLVM_Op<"invoke", [AttrSizedOperandSegments,
+                         DeclareOpInterfaceMethods<BranchOpInterface>,
+                         DeclareOpInterfaceMethods<CallOpInterface>,
+                         DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
+                         Terminator]> {
   let arguments = (ins
                    OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
                    OptionalAttr<FlatSymbolRefAttr>:$callee,
@@ -734,12 +734,13 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
 // CallOp
 //===----------------------------------------------------------------------===//
 
-def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
-                    [AttrSizedOperandSegments,
-                     DeclareOpInterfaceMethods<FastmathFlagsInterface>,
-                     DeclareOpInterfaceMethods<CallOpInterface>,
-                     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
-                     DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
+def LLVM_CallOp
+    : LLVM_MemAccessOpBase<
+          "call", [AttrSizedOperandSegments,
+                   DeclareOpInterfaceMethods<FastmathFlagsInterface>,
+                   DeclareOpInterfaceMethods<CallOpInterface>,
+                   DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+                   DeclareOpInterfaceMethods<WeightedBranchOpInterface>]> {
   let summary = "Call to an LLVM function.";
   let description = [{
     In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
@@ -1047,11 +1048,12 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
     LLVM_TerminatorPassthroughOpBuilder
   ];
 }
-def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
-    [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<BranchOpInterface>,
-     DeclareOpInterfaceMethods<BranchWeightOpInterface>,
-     Pure]> {
+def LLVM_CondBrOp
+    : LLVM_TerminatorOp<
+          "cond_br", [AttrSizedOperandSegments,
+                      DeclareOpInterfaceMethods<BranchOpInterface>,
+                      DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
+                      Pure]> {
   let arguments = (ins I1:$condition,
                    Variadic<LLVM_Type>:$trueDestOperands,
                    Variadic<LLVM_Type>:$falseDestOperands,
@@ -1136,11 +1138,12 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
   }];
 }
 
-def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
-    [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<BranchOpInterface>,
-     DeclareOpInterfaceMethods<BranchWeightOpInterface>,
-     Pure]> {
+def LLVM_SwitchOp
+    : LLVM_TerminatorOp<
+          "switch", [AttrSizedOperandSegments,
+                     DeclareOpInterfaceMethods<BranchOpInterface>,
+                     DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
+                     Pure]> {
   let arguments = (ins
     AnySignlessInteger:$value,
     Variadic<AnyType>:$defaultOperands,
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 7a47b686ac7d1..1923df8308039 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -385,12 +385,12 @@ def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
     operations, i.e. terminator operations with successors.
 
     This interface provides methods for getting/setting integer non-negative
-    weight of each branch in the range from 0 to 100. The sum of weights
-    must be 100. The number of weights must match the number of successors
-    of the operation.
-
-    The weights specify the probability (in percents) of taking
-    a particular branch.
+    weight of each branch. The probability of executing a branch
+    is computed as the ratio between the branch's weight and the total
+    sum of the weights.
+    The number of weights must match the number of successors of the operation,
+    with one exception for CallOpInterface operations, which may only
+    have on weight when they do not have any successors.
 
     The default implementations of the methods expect the operation
     to have an attribute of type DenseI32ArrayAttr named branch_weights.
@@ -440,15 +440,16 @@ def WeightedRegionBranchOpInterface
     that exhibit branching behavior between held regions.
 
     This interface provides methods for getting/setting integer non-negative
-    weight of each branch in the range from 0 to 100. The sum of weights
-    must be 100. The number of weights must match the number of regions
+    weight of each branch. The probability of executing a region is computed
+    as the ratio between the region branch's weight and the total sum
+    of the weights.
+    The number of weights must match the number of regions
     held by the operation (including empty regions).
 
-    The weights specify the probability (in percents) of branching
-    to a particular region when first executing the operation.
+    The weights specify the probability of branching to a particular
+    region when first executing the operation.
     For example, for loop-like operations with a single region
     the weight specifies the probability of entering the loop.
-    In this case, the weight must be either 0 or 100.
 
     The default implementations of the methods expect the operation
     to have an attribute of type DenseI32ArrayAttr named branch_weights.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 97ae14aa0d6af..0f136c5c46d79 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -189,7 +189,7 @@ class ModuleTranslation {
                                   llvm::Instruction *inst);
 
   /// Sets LLVM profiling metadata for operations that have branch weights.
-  void setBranchWeightsMetadata(BranchWeightOpInterface op);
+  void setBranchWeightsMetadata(WeightedBranchOpInterface op);
 
   /// Sets LLVM loop metadata for branch operations that have a loop annotation
   /// attribute.
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index e587e8f1af178..2186925914f03 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,6 +9,7 @@
 #include <utility>
 
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -84,24 +85,33 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 // WeightedBranchOpInterface
 //===----------------------------------------------------------------------===//
 
-LogicalResult detail::verifyBranchWeights(Operation *op) {
-  auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
+static LogicalResult verifyWeights(Operation *op, DenseI32ArrayAttr weights,
+                                   int64_t weightsNum,
+                                   llvm::StringRef weightAnchorName,
+                                   llvm::StringRef weightRefName) {
   if (weights) {
-    if (weights.size() != op->getNumSuccessors())
-      return op->emitError() << "number of weights (" << weights.size()
-                             << ") does not match the number of successors ("
-                             << op->getNumSuccessors() << ")";
-    int32_t total = 0;
-    for (auto weight : llvm::enumerate(weights.asArrayRef())) {
+    if (weights.size() != weightsNum)
+      return op->emitError() << "expects number of " << weightAnchorName
+                             << " weights to match number of " << weightRefName
+                             << ": " << weights.size() << " vs " << weightsNum;
+
+    for (auto weight : llvm::enumerate(weights.asArrayRef()))
       if (weight.value() < 0)
         return op->emitError()
                << "weight #" << weight.index() << " must be non-negative";
-      total += weight.value();
-    }
-    if (total != 100)
-      return op->emitError() << "total weight " << total << " is not 100";
   }
-  return mlir::success();
+  return success();
+}
+
+LogicalResult detail::verifyBranchWeights(Operation *op) {
+  auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
+  unsigned successorsNum = op->getNumSuccessors();
+  // CallOpInterface operations without successors may only have
+  // one weight, though it seems to be redundant and indicate
+  // 100% probability of calling the callee(s).
+  int64_t weightsNum =
+      (successorsNum == 0 && isa<CallOpInterface>(op)) ? 1 : successorsNum;
+  return verifyWeights(op, weights, weightsNum, "branch", "successors");
 }
 
 //===----------------------------------------------------------------------===//
@@ -111,22 +121,7 @@ LogicalResult detail::verifyBranchWeights(Operation *op) {
 LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
   auto weights =
       cast<WeightedRegionBranchOpInterface>(op).getRegionWeightsOrNull();
-  if (weights) {
-    if (weights.size() != op->getNumRegions())
-      return op->emitError() << "number of weights (" << weights.size()
-                             << ") does not match the number of regions ("
-                             << op->getNumRegions() << ")";
-    int32_t total = 0;
-    for (auto weight : llvm::enumerate(weights.asArrayRef())) {
-      if (weight.value() < 0)
-        return op->emitError()
-               << "weight #" << weight.index() << " must be non-negative";
-      total += weight.value();
-    }
-    if (total != 100)
-      return op->emitError() << "total weight " << total << " is not 100";
-  }
-  return mlir::success();
+  return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 1b5ce868b5c77..89045b4a16469 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -146,7 +146,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
     branchWeights.push_back(branchWeight->getZExtValue());
   }
 
-  if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
+  if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
     iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
     return success();
   }
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2702b7aa544da..c4c427ecac091 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1055,7 +1055,7 @@ LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
       return failure();
 
     // Set the branch weight metadata on the translated instruction.
-    if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
+    if (auto iface = dyn_cast<WeightedBranchOpInterface>(op))
       setBranchWeightsMetadata(iface);
   }
 
@@ -2026,7 +2026,7 @@ void ModuleTranslation::setDereferenceableMetadata(
   inst->setMetadata(kindId, derefSizeNode);
 }
 
-void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
+void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) {
   DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
   if (!weightsAttr)
     return;
diff --git a/mlir/test/Dialect/ControlFlow/invalid.mlir b/mlir/test/Dialect/ControlFlow/invalid.mlir
index 6024c6d55ac64..f1973cd4e7931 100644
--- a/mlir/test/Dialect/ControlFlow/invalid.mlir
+++ b/mlir/test/Dialect/ControlFlow/invalid.mlir
@@ -72,7 +72,7 @@ func.func @switch_missing_default(%flag : i32, %caseOperand : i32) {
 
 // CHECK-LABEL: func @wrong_weights_number
 func.func @wrong_weights_number(%cond: i1) {
-  // expected-error at +1 {{number of weights (1) does not match the number of successors (2)}}
+  // expected-error at +1 {{expects number of branch weights to match number of successors: 1 vs 2}}
   cf.cond_br %cond weights([100]), ^bb1, ^bb2
   ^bb1:
     return
@@ -91,15 +91,3 @@ func.func @wrong_total_weight(%cond: i1) {
   ^bb2:
     return
 }
-
-// -----
-
-// CHECK-LABEL: func @wrong_total_weight
-func.func @wrong_total_weight(%cond: i1) {
-  // expected-error at +1 {{total weight 101 is not 100}}
-  cf.cond_br %cond weights([100, 1]), ^bb1, ^bb2
-  ^bb1:
-    return
-  ^bb2:
-    return
-}

>From 1e03d06d11d04e7c4e1893c993ebf678323891a8 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Mon, 2 Jun 2025 17:07:09 -0700
Subject: [PATCH 3/7] Removed redundant verification code.

---
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index c7528c970a4ba..a1b1455336007 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -589,10 +589,6 @@ LogicalResult SwitchOp::verify() {
            static_cast<int64_t>(getCaseDestinations().size())))
     return emitOpError("expects number of case values to match number of "
                        "case destinations");
-  if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
-    return emitError("expects number of branch weights to match number of "
-                     "successors: ")
-           << getBranchWeights()->size() << " vs " << getNumSuccessors();
   if (getCaseValues() &&
       getValue().getType() != getCaseValues()->getElementType())
     return emitError("expects case value type to match condition value type");

>From 4ab5cacf09be2a8ca72c003dd277716e1d4d2bf0 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 4 Jun 2025 14:54:05 -0700
Subject: [PATCH 4/7] Addresses review comments.

---
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 28 +++++++++++--------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2186925914f03..5b9b166cc4650 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -86,20 +86,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verifyWeights(Operation *op, DenseI32ArrayAttr weights,
-                                   int64_t weightsNum,
+                                   int64_t expectedWeightsNum,
                                    llvm::StringRef weightAnchorName,
                                    llvm::StringRef weightRefName) {
-  if (weights) {
-    if (weights.size() != weightsNum)
-      return op->emitError() << "expects number of " << weightAnchorName
-                             << " weights to match number of " << weightRefName
-                             << ": " << weights.size() << " vs " << weightsNum;
-
-    for (auto weight : llvm::enumerate(weights.asArrayRef()))
-      if (weight.value() < 0)
-        return op->emitError()
-               << "weight #" << weight.index() << " must be non-negative";
-  }
+  if (!weights)
+    return success();
+
+  if (weights.size() != expectedWeightsNum)
+    return op->emitError() << "expects number of " << weightAnchorName
+                           << " weights to match number of " << weightRefName
+                           << ": " << weights.size() << " vs "
+                           << expectedWeightsNum;
+
+  for (auto [index, weight] : llvm::enumerate(weights.asArrayRef()))
+    if (weight < 0)
+      return op->emitError() << "weight #" << index << " must be non-negative";
+
   return success();
 }
 
@@ -109,6 +111,8 @@ LogicalResult detail::verifyBranchWeights(Operation *op) {
   // CallOpInterface operations without successors may only have
   // one weight, though it seems to be redundant and indicate
   // 100% probability of calling the callee(s).
+  // TODO: maybe we should remove this interface for calls without
+  // successors.
   int64_t weightsNum =
       (successorsNum == 0 && isa<CallOpInterface>(op)) ? 1 : successorsNum;
   return verifyWeights(op, weights, weightsNum, "branch", "successors");

>From 6963e7d98369e99fe42a4b96839465f54ddf6cc3 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 11 Jun 2025 11:47:23 -0700
Subject: [PATCH 5/7] Addressed review comments.

---
 .../Transforms/ControlFlowConverter.cpp       |  4 +-
 .../mlir/Interfaces/ControlFlowInterfaces.td  | 46 +++++++++++--------
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   |  5 +-
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 21 +++++----
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        |  2 +-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  |  7 +--
 6 files changed, 47 insertions(+), 38 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index 5256ef8d53d85..2f65c42365645 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -215,8 +215,8 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
     auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
         loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
-    if (auto weights = ifOp.getRegionWeightsOrNull())
-      branchOp.setBranchWeights(weights);
+    if (auto weights = ifOp.getWeights())
+      branchOp.setWeights(weights);
     rewriter.replaceOp(ifOp, continueBlock->getArguments());
     return success();
   }
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 1923df8308039..ab95de8e4d0fe 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -388,9 +388,10 @@ def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
     weight of each branch. The probability of executing a branch
     is computed as the ratio between the branch's weight and the total
     sum of the weights.
-    The number of weights must match the number of successors of the operation,
+    The weights are optional. If they are provided, then their number
+    must match the number of successors of the operation,
     with one exception for CallOpInterface operations, which may only
-    have on weight when they do not have any successors.
+    have one weight when they do not have any successors.
 
     The default implementations of the methods expect the operation
     to have an attribute of type DenseI32ArrayAttr named branch_weights.
@@ -398,24 +399,26 @@ def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
   let cppNamespace = "::mlir";
 
   let methods = [InterfaceMethod<
-                     /*desc=*/"Returns the branch weights attribute or nullptr",
-                     /*returnType=*/"::mlir::DenseI32ArrayAttr",
-                     /*methodName=*/"getBranchWeightsOrNull",
+                     /*desc=*/"Returns the branch weights",
+                     /*returnType=*/"::llvm::ArrayRef<int32_t>",
+                     /*methodName=*/"getWeights",
                      /*args=*/(ins),
                      /*methodBody=*/[{}],
                      /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
-        return op.getBranchWeightsAttr();
+        if (auto attr = op.getBranchWeightsAttr())
+          return attr.asArrayRef();
+        return {};
       }]>,
                  InterfaceMethod<
-                     /*desc=*/"Sets the branch weights attribute",
+                     /*desc=*/"Sets the branch weights",
                      /*returnType=*/"void",
-                     /*methodName=*/"setBranchWeights",
-                     /*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
+                     /*methodName=*/"setWeights",
+                     /*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
                      /*methodBody=*/[{}],
                      /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
-        op.setBranchWeightsAttr(attr);
+        op.setBranchWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
       }]>,
   ];
 
@@ -443,8 +446,9 @@ def WeightedRegionBranchOpInterface
     weight of each branch. The probability of executing a region is computed
     as the ratio between the region branch's weight and the total sum
     of the weights.
-    The number of weights must match the number of regions
-    held by the operation (including empty regions).
+    The weights are optional. If they are provided, then their number
+    must match the number of regions held by the operation
+    (including empty regions).
 
     The weights specify the probability of branching to a particular
     region when first executing the operation.
@@ -457,24 +461,26 @@ def WeightedRegionBranchOpInterface
   let cppNamespace = "::mlir";
 
   let methods = [InterfaceMethod<
-                     /*desc=*/"Returns the region weights attribute or nullptr",
-                     /*returnType=*/"::mlir::DenseI32ArrayAttr",
-                     /*methodName=*/"getRegionWeightsOrNull",
+                     /*desc=*/"Returns the region weights",
+                     /*returnType=*/"::llvm::ArrayRef<int32_t>",
+                     /*methodName=*/"getWeights",
                      /*args=*/(ins),
                      /*methodBody=*/[{}],
                      /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
-        return op.getRegionWeightsAttr();
+        if (auto attr = op.getRegionWeightsAttr())
+          return attr.asArrayRef();
+        return {};
       }]>,
                  InterfaceMethod<
-                     /*desc=*/"Sets the region weights attribute",
+                     /*desc=*/"Sets the region weights",
                      /*returnType=*/"void",
-                     /*methodName=*/"setRegionWeights",
-                     /*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
+                     /*methodName=*/"setWeights",
+                     /*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
                      /*methodBody=*/[{}],
                      /*defaultImpl=*/[{
         auto op = cast<ConcreteOp>(this->getOperation());
-        op.setRegionWeightsAttr(attr);
+        op.setRegionWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
       }]>,
   ];
 
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 12769e486a3c7..d31d7d801e149 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -170,8 +170,9 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
         op, adaptor.getCondition(), *convertedTrueBlock,
         adaptor.getTrueDestOperands(), *convertedFalseBlock,
         adaptor.getFalseDestOperands());
-    if (auto weights = op.getBranchWeightsOrNull()) {
-      newOp.setBranchWeights(weights);
+    ArrayRef<int32_t> weights = op.getWeights();
+    if (!weights.empty()) {
+      newOp.setWeights(weights);
       op.removeBranchWeightsAttr();
     }
     // TODO: We should not just forward all attributes like that. But there are
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 5b9b166cc4650..8db0b34099eba 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -85,11 +85,12 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
 // WeightedBranchOpInterface
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyWeights(Operation *op, DenseI32ArrayAttr weights,
-                                   int64_t expectedWeightsNum,
+static LogicalResult verifyWeights(Operation *op,
+                                   llvm::ArrayRef<int32_t> weights,
+                                   std::size_t expectedWeightsNum,
                                    llvm::StringRef weightAnchorName,
                                    llvm::StringRef weightRefName) {
-  if (!weights)
+  if (weights.empty())
     return success();
 
   if (weights.size() != expectedWeightsNum)
@@ -98,7 +99,7 @@ static LogicalResult verifyWeights(Operation *op, DenseI32ArrayAttr weights,
                            << ": " << weights.size() << " vs "
                            << expectedWeightsNum;
 
-  for (auto [index, weight] : llvm::enumerate(weights.asArrayRef()))
+  for (auto [index, weight] : llvm::enumerate(weights))
     if (weight < 0)
       return op->emitError() << "weight #" << index << " must be non-negative";
 
@@ -106,14 +107,14 @@ static LogicalResult verifyWeights(Operation *op, DenseI32ArrayAttr weights,
 }
 
 LogicalResult detail::verifyBranchWeights(Operation *op) {
-  auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
+  llvm::ArrayRef<int32_t> weights =
+      cast<WeightedBranchOpInterface>(op).getWeights();
   unsigned successorsNum = op->getNumSuccessors();
   // CallOpInterface operations without successors may only have
   // one weight, though it seems to be redundant and indicate
   // 100% probability of calling the callee(s).
-  // TODO: maybe we should remove this interface for calls without
-  // successors.
-  int64_t weightsNum =
+  // TODO: maybe we should disallow weights for calls without successors.
+  std::size_t weightsNum =
       (successorsNum == 0 && isa<CallOpInterface>(op)) ? 1 : successorsNum;
   return verifyWeights(op, weights, weightsNum, "branch", "successors");
 }
@@ -123,8 +124,8 @@ LogicalResult detail::verifyBranchWeights(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
-  auto weights =
-      cast<WeightedRegionBranchOpInterface>(op).getRegionWeightsOrNull();
+  llvm::ArrayRef<int32_t> weights =
+      cast<WeightedRegionBranchOpInterface>(op).getWeights();
   return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 89045b4a16469..5fac8867a8545 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -147,7 +147,7 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
   }
 
   if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
-    iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
+    iface.setWeights(branchWeights);
     return success();
   }
   return failure();
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index c4c427ecac091..b5b1ea1b2244a 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -2027,13 +2027,14 @@ void ModuleTranslation::setDereferenceableMetadata(
 }
 
 void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) {
-  DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
-  if (!weightsAttr)
+  SmallVector<uint32_t> weights;
+  llvm::transform(op.getWeights(), std::back_inserter(weights),
+                  [](int32_t value) { return static_cast<uint32_t>(value); });
+  if (weights.empty())
     return;
 
   llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
   assert(inst && "expected the operation to have a mapping to an instruction");
-  SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
   inst->setMetadata(
       llvm::LLVMContext::MD_prof,
       llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));

>From 1caee0a03da57d5cb9e7e574c38446b3abf6f0a8 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 11 Jun 2025 16:54:24 -0700
Subject: [PATCH 6/7] Fixed flang build.

---
 flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index 2f65c42365645..3d35803e6a2d3 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -215,7 +215,8 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
     auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
         loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
-    if (auto weights = ifOp.getWeights())
+    llvm::ArrayRef<int32_t> weights = ifOp.getWeights();
+    if (!weights.empty())
       branchOp.setWeights(weights);
     rewriter.replaceOp(ifOp, continueBlock->getArguments());
     return success();

>From dc02ac303bb13b5eaf3132d19d8b186ececdd72e Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 11 Jun 2025 19:57:50 -0700
Subject: [PATCH 7/7] Strengthened constraint on the number of weights for
 calls.

---
 .../mlir/Interfaces/ControlFlowInterfaces.td  |  8 +++---
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 13 ++++------
 .../LLVMIR/LLVMIRToLLVMTranslation.cpp        |  9 ++++++-
 mlir/test/Dialect/ControlFlow/invalid.mlir    | 12 +++++++++
 .../LLVMIR/Import/metadata-profiling.ll       | 13 ++++++----
 mlir/test/Target/LLVMIR/llvmir-invalid.mlir   | 20 ++++++++++++++
 mlir/test/Target/LLVMIR/llvmir.mlir           | 26 -------------------
 7 files changed, 56 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index ab95de8e4d0fe..46ab0b9ebbc6b 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -387,11 +387,9 @@ def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
     This interface provides methods for getting/setting integer non-negative
     weight of each branch. The probability of executing a branch
     is computed as the ratio between the branch's weight and the total
-    sum of the weights.
+    sum of the weights (which cannot be zero).
     The weights are optional. If they are provided, then their number
-    must match the number of successors of the operation,
-    with one exception for CallOpInterface operations, which may only
-    have one weight when they do not have any successors.
+    must match the number of successors of the operation.
 
     The default implementations of the methods expect the operation
     to have an attribute of type DenseI32ArrayAttr named branch_weights.
@@ -445,7 +443,7 @@ def WeightedRegionBranchOpInterface
     This interface provides methods for getting/setting integer non-negative
     weight of each branch. The probability of executing a region is computed
     as the ratio between the region branch's weight and the total sum
-    of the weights.
+    of the weights (which cannot be zero).
     The weights are optional. If they are provided, then their number
     must match the number of regions held by the operation
     (including empty regions).
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 8db0b34099eba..3a63db35eec0f 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -103,20 +103,17 @@ static LogicalResult verifyWeights(Operation *op,
     if (weight < 0)
       return op->emitError() << "weight #" << index << " must be non-negative";
 
+  if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
+    return op->emitError() << "branch weights cannot all be zero";
+
   return success();
 }
 
 LogicalResult detail::verifyBranchWeights(Operation *op) {
   llvm::ArrayRef<int32_t> weights =
       cast<WeightedBranchOpInterface>(op).getWeights();
-  unsigned successorsNum = op->getNumSuccessors();
-  // CallOpInterface operations without successors may only have
-  // one weight, though it seems to be redundant and indicate
-  // 100% probability of calling the callee(s).
-  // TODO: maybe we should disallow weights for calls without successors.
-  std::size_t weightsNum =
-      (successorsNum == 0 && isa<CallOpInterface>(op)) ? 1 : successorsNum;
-  return verifyWeights(op, weights, weightsNum, "branch", "successors");
+  return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
+                       "successors");
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
index 5fac8867a8545..e67aa892afe09 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
@@ -147,7 +147,14 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
   }
 
   if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
-    iface.setWeights(branchWeights);
+    // LLVM allows attaching a single weight to call instructions.
+    // This is used for carrying the execution count information
+    // in PGO modes. MLIR WeightedBranchOpInterface does not allow this,
+    // so we drop the metadata in this case.
+    // LLVM should probably use the VP form of MD_prof metadata
+    // for such cases.
+    if (op->getNumSuccessors() != 0)
+      iface.setWeights(branchWeights);
     return success();
   }
   return failure();
diff --git a/mlir/test/Dialect/ControlFlow/invalid.mlir b/mlir/test/Dialect/ControlFlow/invalid.mlir
index f1973cd4e7931..1b8de22a9ff9f 100644
--- a/mlir/test/Dialect/ControlFlow/invalid.mlir
+++ b/mlir/test/Dialect/ControlFlow/invalid.mlir
@@ -91,3 +91,15 @@ func.func @wrong_total_weight(%cond: i1) {
   ^bb2:
     return
 }
+
+// -----
+
+// CHECK-LABEL: func @zero_weights
+func.func @zero_weights(%cond: i1) {
+  // expected-error@+1 {{branch weights cannot all be zero}}
+  cf.cond_br %cond weights([0, 0]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
index cc3b47a54dfe9..c623df0b605b2 100644
--- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
+++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
@@ -36,14 +36,17 @@ bbd:
 
 ; // -----
 
+; Verify that a single weight attached to a call is not translated.
+; The MLIR WeightedBranchOpInterface does not support this case.
+
 ; CHECK: llvm.func @fn()
-declare void @fn()
+declare i32 @fn()
 
 ; CHECK-LABEL: @call_branch_weights
-define void @call_branch_weights() {
-  ; CHECK:  llvm.call @fn() {branch_weights = array<i32: 42>}
-  call void @fn(), !prof !0
-  ret void
+define i32 @call_branch_weights() {
+  ; CHECK:  llvm.call @fn() : () -> i32
+  %1 = call i32 @fn(), !prof !0
+  ret i32 %1
 }
 
 !0 = !{!"branch_weights", i32 42}
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 24a7b42557278..660520b948d6c 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -448,3 +448,23 @@ llvm.mlir.global external constant @const() {addr_space = 0 : i32, dso_local} :
 }
 
 llvm.func extern_weak @extern_func()
+
+// -----
+
+llvm.func @fn()
+
+llvm.func @call_branch_weights() {
+  // expected-error @below{{expects number of branch weights to match number of successors: 1 vs 0}}
+  llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @fn() -> i32
+
+llvm.func @call_branch_weights() {
+  // expected-error @below{{expects number of branch weights to match number of successors: 1 vs 0}}
+  %res = llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> i32
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 7742259e7a478..fc1993b50ba2d 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1906,32 +1906,6 @@ llvm.func @cond_br_weights(%cond : i1, %arg0 : i32,  %arg1 : i32) -> i32 {
 
 // -----
 
-llvm.func @fn()
-
-// CHECK-LABEL: @call_branch_weights
-llvm.func @call_branch_weights() {
-  // CHECK: !prof ![[NODE:[0-9]+]]
-  llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
-  llvm.return
-}
-
-// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
-
-// -----
-
-llvm.func @fn() -> i32
-
-// CHECK-LABEL: @call_branch_weights
-llvm.func @call_branch_weights() {
-  // CHECK: !prof ![[NODE:[0-9]+]]
-  %res = llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> i32
-  llvm.return
-}
-
-// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
-
-// -----
-
 llvm.func @foo()
 llvm.func @__gxx_personality_v0(...) -> i32
 



More information about the flang-commits mailing list