[flang-commits] [flang] [mlir] [mlir][flang] Added Weighted[Region]BranchOpInterface's. (PR #142079)

Slava Zakharin via flang-commits flang-commits at lists.llvm.org
Thu May 29 19:48:31 PDT 2025


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/142079

The new interfaces provide getters and setters for the weight
information about the branches of BranchOpInterface and
RegionBranchOpInterface operations.

These interfaces are modeled after the LLVM dialect's
BranchWeightOpInterface.

The plan is to produce this information in Flang, e.g. to mark
code that is most likely "cold" as such, and to allow LLVM to order
basic blocks accordingly. An example of such code is the
copy loops generated for array repacking: we can mark them
as "cold", assuming that the copy will usually not happen at run time.
If the copy does happen, its cost is probably high enough
that the small overhead of jumping to the "cold" code and
fetching it does not matter.


>From f258ed9be16829b6d5c9261c1a0b153c697271e7 Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Thu, 29 May 2025 19:09:16 -0700
Subject: [PATCH] [mlir][flang] Added Weighted[Region]BranchOpInterface's.

The new interfaces provide getters and setters for the weight
information about the branches of BranchOpInterface and
RegionBranchOpInterface operations.

These interfaces are modeled after the LLVM dialect's
BranchWeightOpInterface.

The plan is to produce this information in Flang, e.g. to mark
code that is most likely "cold" as such, and to allow LLVM to order
basic blocks accordingly. An example of such code is the
copy loops generated for array repacking: we can mark them
as "cold", assuming that the copy will usually not happen at run time.
If the copy does happen, its cost is probably high enough
that the small overhead of jumping to the "cold" code and
fetching it does not matter.
---
 .../include/flang/Optimizer/Dialect/FIROps.td |  18 ++-
 flang/lib/Optimizer/Dialect/FIROps.cpp        |  21 +++-
 .../Transforms/ControlFlowConverter.cpp       |   4 +-
 flang/test/Fir/cfg-conversion-if.fir          |  46 ++++++++
 flang/test/Fir/fir-ops.fir                    |  16 +++
 flang/test/Fir/invalid.fir                    |  37 ++++++
 .../Dialect/ControlFlow/IR/ControlFlowOps.td  |  34 +++---
 .../mlir/Interfaces/ControlFlowInterfaces.h   |  20 ++++
 .../mlir/Interfaces/ControlFlowInterfaces.td  | 107 ++++++++++++++++++
 .../ControlFlowToLLVM/ControlFlowToLLVM.cpp   |   6 +-
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp |  49 ++++++++
 .../Conversion/ControlFlowToLLVM/branch.mlir  |  14 +++
 mlir/test/Dialect/ControlFlow/invalid.mlir    |  36 ++++++
 mlir/test/Dialect/ControlFlow/ops.mlir        |  10 ++
 14 files changed, 396 insertions(+), 22 deletions(-)
 create mode 100644 flang/test/Fir/cfg-conversion-if.fir

diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index f4b17ef7eed09..7001e25a9bcda 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
   }];
 }
 
-def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
-    "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
-    NoRegionArguments]> {
+def fir_IfOp
+    : region_Op<
+          "if", [DeclareOpInterfaceMethods<
+                     RegionBranchOpInterface, ["getRegionInvocationBounds",
+                                               "getEntrySuccessorRegions"]>,
+                 RecursiveMemoryEffects, NoRegionArguments,
+                 WeightedRegionBranchOpInterface]> {
   let summary = "if-then-else conditional operation";
   let description = [{
     Used to conditionally execute operations. This operation is the FIR
@@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
     ```
   }];
 
-  let arguments = (ins I1:$condition);
+  let arguments = (ins I1:$condition,
+      OptionalAttr<DenseI32ArrayAttr>:$region_weights);
   let results = (outs Variadic<AnyType>:$results);
 
   let regions = (region
@@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
 
     void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
                            unsigned resultNum);
+
+    /// Returns the assembly keyword that introduces the region_weights value.
+    static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
+      return "weights";
+    }
   }];
 }
 
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index cbe93907265f6..2949120894132 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
       parser.resolveOperand(cond, i1Type, result.operands))
     return mlir::failure();
 
+  if (mlir::succeeded(
+          parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
+    if (parser.parseLParen())
+      return mlir::failure();
+    mlir::DenseI32ArrayAttr weights; // Parsed from e.g. `[10, 90]`.
+    if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
+      return mlir::failure();
+    if (weights) // Presumably non-null after a successful parse.
+      result.addAttribute(getRegionWeightsAttrName(result.name), weights);
+    if (parser.parseRParen())
+      return mlir::failure();
+  }
+
   if (parser.parseOptionalArrowTypeList(result.types))
     return mlir::failure();
 
@@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
 void fir::IfOp::print(mlir::OpAsmPrinter &p) {
   bool printBlockTerminators = false;
   p << ' ' << getCondition();
+  if (auto weights = getRegionWeightsAttr()) { // Printed as `weights([...])`.
+    p << ' ' << getWeightsAttrAssemblyName() << '(';
+    p.printStrippedAttrOrType(weights);
+    p << ')';
+  }
   if (!getResults().empty()) {
     p << " -> (" << getResultTypes() << ')';
     printBlockTerminators = true;
@@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
     p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
                   printBlockTerminators);
   }
-  p.printOptionalAttrDict((*this)->getAttrs());
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elideAttrs=*/{getRegionWeightsAttrName()});
 }
 
 void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
index 8a9e9b80134b8..5256ef8d53d85 100644
--- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
+++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
@@ -212,9 +212,11 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
     }
 
     rewriter.setInsertionPointToEnd(condBlock);
-    rewriter.create<mlir::cf::CondBranchOp>(
+    auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
         loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
         otherwiseBlock, llvm::ArrayRef<mlir::Value>());
+    if (auto weights = ifOp.getRegionWeightsOrNull())
+      branchOp.setBranchWeights(weights); // [then, else] -> [true, false]
     rewriter.replaceOp(ifOp, continueBlock->getArguments());
     return success();
   }
diff --git a/flang/test/Fir/cfg-conversion-if.fir b/flang/test/Fir/cfg-conversion-if.fir
new file mode 100644
index 0000000000000..1e30ee8e64f02
--- /dev/null
+++ b/flang/test/Fir/cfg-conversion-if.fir
@@ -0,0 +1,46 @@
+// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s
+
+func.func private @callee() -> none
+
+// CHECK-LABEL:   func.func @if_then(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) {
+// CHECK:           cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           %[[VAL_0:.*]] = fir.call @callee() : () -> none
+// CHECK:           cf.br ^bb2
+// CHECK:         ^bb2:
+// CHECK:           return
+// CHECK:         }
+func.func @if_then(%cond: i1) {
+  fir.if %cond weights([10, 90]) { // weights are forwarded to cf.cond_br
+    fir.call @callee() : () -> none
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @if_then_else(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) -> i32 {
+// CHECK:           %[[VAL_0:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_1:.*]] = arith.constant 1 : i32
+// CHECK:           cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           cf.br ^bb3(%[[VAL_0]] : i32)
+// CHECK:         ^bb2:
+// CHECK:           cf.br ^bb3(%[[VAL_1]] : i32)
+// CHECK:         ^bb3(%[[VAL_2:.*]]: i32):
+// CHECK:           cf.br ^bb4
+// CHECK:         ^bb4:
+// CHECK:           return %[[VAL_2]] : i32
+// CHECK:         }
+func.func @if_then_else(%cond: i1) -> i32 {
+  %c0 = arith.constant 0 : i32
+  %c1 = arith.constant 1 : i32
+  %result = fir.if %cond weights([90, 10]) -> i32 { // forwarded to cf.cond_br
+    fir.result %c0 : i32
+  } else {
+    fir.result %c1 : i32
+  }
+  return %result : i32
+}
diff --git a/flang/test/Fir/fir-ops.fir b/flang/test/Fir/fir-ops.fir
index 9c444d2f4e0bc..3585bf9efca3e 100644
--- a/flang/test/Fir/fir-ops.fir
+++ b/flang/test/Fir/fir-ops.fir
@@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
   %6 = arith.addi %2, %5 : index
   return %6 : index
 }
+
+// CHECK-LABEL:   func.func @test_if_weights(
+// CHECK-SAME:      %[[ARG0:.*]]: i1) {
+func.func @test_if_weights(%cond: i1) {
+// CHECK:           fir.if %[[ARG0]] weights([99, 1]) {
+// CHECK:           }
+  fir.if %cond weights([99, 1]) {
+  }
+// CHECK:           fir.if %[[ARG0]] weights([99, 1]) {
+// CHECK:           } else {
+// CHECK:           }
+  fir.if %cond weights ([99,1]) { // irregular spacing is still accepted
+  } else {
+  }
+  return
+}
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index fd607fd9066f7..0391cdbef71e5 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -1385,3 +1385,40 @@ fir.local {type = local_init} @x.localizer : f32 init {
 ^bb0(%arg0: f32, %arg1: f32):
   fir.yield(%arg0 : f32)
 }
+
+// -----
+
+func.func @wrong_weights_number_in_if_then(%cond: i1) {
+// expected-error @below {{number of weights (1) does not match the number of regions (2)}}
+  fir.if %cond weights([50]) { // fir.if has two regions even without an else
+  }
+  return
+}
+
+// -----
+
+func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
+// expected-error @below {{number of weights (3) does not match the number of regions (2)}}
+  fir.if %cond weights([50, 40, 10]) {
+  } else {
+  }
+  return
+}
+
+// -----
+
+func.func @negative_weight_in_if_then(%cond: i1) {
+// expected-error @below {{weight #0 must be non-negative}}
+  fir.if %cond weights([-1, 101]) {
+  }
+  return
+}
+
+// -----
+
+func.func @wrong_total_weight_in_if_then(%cond: i1) {
+// expected-error @below {{total weight 101 is not 100}}
+  fir.if %cond weights([1, 100]) {
+  }
+  return
+}
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index 48f12b46a57f1..79da81ba049dd 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
 // CondBranchOp
 //===----------------------------------------------------------------------===//
 
-def CondBranchOp : CF_Op<"cond_br",
-    [AttrSizedOperandSegments,
-     DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
-     Pure, Terminator]> {
+def CondBranchOp
+    : CF_Op<"cond_br", [AttrSizedOperandSegments,
+                        DeclareOpInterfaceMethods<
+                            BranchOpInterface, ["getSuccessorForOperands"]>,
+                        WeightedBranchOpInterface, Pure, Terminator]> {
   let summary = "Conditional branch operation";
   let description = [{
     The `cf.cond_br` terminator operation represents a conditional branch on a
@@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
     ```
   }];
 
-  let arguments = (ins I1:$condition,
-                       Variadic<AnyType>:$trueDestOperands,
-                       Variadic<AnyType>:$falseDestOperands);
+  let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
+      Variadic<AnyType>:$falseDestOperands,
+      OptionalAttr<DenseI32ArrayAttr>:$branch_weights); // one per successor
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
 
-  let builders = [
-    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
-      "ValueRange":$trueOperands, "Block *":$falseDest,
-      "ValueRange":$falseOperands), [{
-      build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
+  let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+                                "ValueRange":$trueOperands,
+                                "Block *":$falseDest,
+                                "ValueRange":$falseOperands),
+                            [{
+      build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
             falseDest);
     }]>,
-    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
-      "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
+                  OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
+                                "Block *":$falseDest,
+                                CArg<"ValueRange", "{}">:$falseOperands),
+                            [{
       build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
             falseOperands);
     }]>];
@@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",
 
   let hasCanonicalizer = 1;
   let assemblyFormat = [{
-    $condition `,`
+    $condition (`weights` `(` $branch_weights^ `)` )? `,`
     $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
     $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
     attr-dict
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 7f6967f11444f..d63800c12d132 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
                                             const SuccessorOperands &operands);
 } // namespace detail
 
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Verify the optional branch weights of a WeightedBranchOpInterface op: one
+/// non-negative weight per successor, with the weights summing to 100.
+LogicalResult verifyBranchWeights(Operation *op);
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Verify the optional region weights of a WeightedRegionBranchOpInterface
+/// op: one non-negative weight per region, with the weights summing to 100.
+LogicalResult verifyRegionBranchWeights(Operation *op);
+} // namespace detail
+
 //===----------------------------------------------------------------------===//
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 69bce78e946c8..7a47b686ac7d1 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -375,6 +375,113 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
+  let description = [{
+    This interface provides weight information for branching terminator
+    operations, i.e. terminator operations with successors.
+
+    This interface provides methods for getting/setting a non-negative
+    integer weight for each branch, in the range from 0 to 100. The sum of
+    the weights must be 100. The number of weights must match the number of
+    successors of the operation.
+
+    The weights specify the probability (as a percentage) of taking
+    a particular branch.
+
+    The default implementations of the methods expect the operation
+    to have an attribute of type DenseI32ArrayAttr named branch_weights.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [InterfaceMethod<
+                     /*desc=*/"Returns the branch weights attribute or nullptr",
+                     /*returnType=*/"::mlir::DenseI32ArrayAttr",
+                     /*methodName=*/"getBranchWeightsOrNull",
+                     /*args=*/(ins),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getBranchWeightsAttr();
+      }]>,
+                 InterfaceMethod<
+                     /*desc=*/"Sets the branch weights attribute",
+                     /*returnType=*/"void",
+                     /*methodName=*/"setBranchWeights",
+                     /*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        op.setBranchWeightsAttr(attr);
+      }]>,
+  ];
+
+  let verify = [{
+    return ::mlir::detail::verifyBranchWeights($_op);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+// TODO: the probabilities of entering a particular region seem
+// to correlate with the values returned by
+// RegionBranchOpInterface::getRegionInvocationBounds(), and we should
+// probably verify that the values are consistent. In that case, should
+// WeightedRegionBranchOpInterface extend RegionBranchOpInterface?
+def WeightedRegionBranchOpInterface
+    : OpInterface<"WeightedRegionBranchOpInterface"> {
+  let description = [{
+    This interface provides weight information for region operations
+    that exhibit branching behavior between held regions.
+
+    This interface provides methods for getting/setting a non-negative
+    integer weight for each region, in the range from 0 to 100. The sum of
+    the weights must be 100. The number of weights must match the number of
+    regions held by the operation (including empty regions).
+
+    The weights specify the probability (as a percentage) of branching
+    to a particular region when first executing the operation.
+    For example, for a loop-like operation with a single region the weight
+    would be the probability of entering the loop; note that, since the
+    weights must sum to 100, the only weight accepted in that case is 100.
+
+    The default implementations of the methods expect the operation
+    to have an attribute of type DenseI32ArrayAttr named region_weights.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [InterfaceMethod<
+                     /*desc=*/"Returns the region weights attribute or nullptr",
+                     /*returnType=*/"::mlir::DenseI32ArrayAttr",
+                     /*methodName=*/"getRegionWeightsOrNull",
+                     /*args=*/(ins),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getRegionWeightsAttr();
+      }]>,
+                 InterfaceMethod<
+                     /*desc=*/"Sets the region weights attribute",
+                     /*returnType=*/"void",
+                     /*methodName=*/"setRegionWeights",
+                     /*args=*/(ins "::mlir::DenseI32ArrayAttr":$attr),
+                     /*methodBody=*/[{}],
+                     /*defaultImpl=*/[{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        op.setRegionWeightsAttr(attr);
+      }]>,
+  ];
+
+  let verify = [{
+    return ::mlir::detail::verifyRegionBranchWeights($_op);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index debfd003bd5b5..12769e486a3c7 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -166,10 +166,14 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
                           TypeRange(adaptor.getFalseDestOperands()));
     if (failed(convertedFalseBlock))
       return failure();
-    Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+    auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
         op, adaptor.getCondition(), *convertedTrueBlock,
         adaptor.getTrueDestOperands(), *convertedFalseBlock,
         adaptor.getFalseDestOperands());
+    // Transfer weights explicitly; do not mutate `op` (e.g. drop its weights)
+    // here, since changes made outside the rewriter are not tracked/undone.
+    if (auto weights = op.getBranchWeightsOrNull())
+      newOp.setBranchWeights(weights);
     // TODO: We should not just forward all attributes like that. But there are
     // existing Flang tests that depend on this behavior.
     newOp->setAttrs(op->getAttrDictionary());
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2ae334b517a31..e587e8f1af178 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -80,6 +80,55 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Weighted[Region]BranchOpInterface verification helpers
+//===----------------------------------------------------------------------===//
+
+/// Verifies an optional weights attribute shared by the
+/// Weighted[Region]BranchOpInterface verifiers: a null `weights` is accepted;
+/// otherwise there must be exactly `expectedCount` non-negative weights
+/// summing to 100. `kind` names the weighted entities ("successors" or
+/// "regions") in the diagnostics.
+static LogicalResult verifyWeightsAttr(Operation *op, DenseI32ArrayAttr weights,
+                                       unsigned expectedCount, StringRef kind) {
+  if (!weights)
+    return success();
+  if (weights.size() != expectedCount)
+    return op->emitError() << "number of weights (" << weights.size()
+                           << ") does not match the number of " << kind
+                           << " (" << expectedCount << ")";
+  // Accumulate in 64 bits: the weights are only known to be non-negative, so
+  // summing e.g. [INT32_MAX, 1] in int32_t would be signed overflow (UB).
+  int64_t total = 0;
+  for (auto [index, weight] : llvm::enumerate(weights.asArrayRef())) {
+    if (weight < 0)
+      return op->emitError() << "weight #" << index << " must be non-negative";
+    total += weight;
+  }
+  if (total != 100)
+    return op->emitError() << "total weight " << total << " is not 100";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WeightedBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult detail::verifyBranchWeights(Operation *op) {
+  auto weights = cast<WeightedBranchOpInterface>(op).getBranchWeightsOrNull();
+  return verifyWeightsAttr(op, weights, op->getNumSuccessors(), "successors");
+}
+
+//===----------------------------------------------------------------------===//
+// WeightedRegionBranchOpInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
+  auto weights =
+      cast<WeightedRegionBranchOpInterface>(op).getRegionWeightsOrNull();
+  return verifyWeightsAttr(op, weights, op->getNumRegions(), "regions");
+}
+
 //===----------------------------------------------------------------------===//
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
index 9a0f2b7714544..7c78211d59010 100644
--- a/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
+++ b/mlir/test/Conversion/ControlFlowToLLVM/branch.mlir
@@ -67,3 +67,17 @@ func.func @unreachable_block() {
 ^bb1(%arg0: index):
   cf.br ^bb1(%arg0 : index)
 }
+
+// -----
+
+// Check that cf.cond_br branch weights are carried over to llvm.cond_br.
+
+// CHECK-LABEL:   func.func @cf_cond_br_with_weights(
+func.func @cf_cond_br_with_weights(%cond: i1, %a: index, %b: index) -> index {
+// CHECK:           llvm.cond_br %{{.*}} weights([90, 10]), ^bb1(%{{.*}} : i64), ^bb2(%{{.*}} : i64)
+  cf.cond_br %cond, ^bb1(%a : index), ^bb2(%b : index) {branch_weights = array<i32: 90, 10>}
+^bb1(%arg1: index):
+  return %arg1 : index
+^bb2(%arg2: index):
+  return %arg2 : index
+}
diff --git a/mlir/test/Dialect/ControlFlow/invalid.mlir b/mlir/test/Dialect/ControlFlow/invalid.mlir
index b51d8095c9974..6024c6d55ac64 100644
--- a/mlir/test/Dialect/ControlFlow/invalid.mlir
+++ b/mlir/test/Dialect/ControlFlow/invalid.mlir
@@ -67,3 +67,39 @@ func.func @switch_missing_default(%flag : i32, %caseOperand : i32) {
   ^bb3(%bb3arg : i32):
     return
 }
+
+// -----
+
+// CHECK-LABEL: func @wrong_weights_number
+func.func @wrong_weights_number(%cond: i1) {
+  // expected-error@+1 {{number of weights (1) does not match the number of successors (2)}}
+  cf.cond_br %cond weights([100]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @negative_weight
+func.func @negative_weight(%cond: i1) {
+  // expected-error@+1 {{weight #0 must be non-negative}}
+  cf.cond_br %cond weights([-1, 101]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @wrong_total_weight
+func.func @wrong_total_weight(%cond: i1) {
+  // expected-error@+1 {{total weight 101 is not 100}}
+  cf.cond_br %cond weights([100, 1]), ^bb1, ^bb2
+  ^bb1:
+    return
+  ^bb2:
+    return
+}
diff --git a/mlir/test/Dialect/ControlFlow/ops.mlir b/mlir/test/Dialect/ControlFlow/ops.mlir
index c9317c7613972..160534240e0fa 100644
--- a/mlir/test/Dialect/ControlFlow/ops.mlir
+++ b/mlir/test/Dialect/ControlFlow/ops.mlir
@@ -51,3 +51,13 @@ func.func @switch_result_number(%arg0: i32) {
   ^bb2:
     return
 }
+
+// CHECK-LABEL: func @cond_weights
+func.func @cond_weights(%cond: i1) {
+// CHECK: cf.cond_br %{{.*}} weights([60, 40]), ^{{.*}}, ^{{.*}}
+  cf.cond_br %cond weights([60, 40]), ^bb1, ^bb2 // one weight per successor
+  ^bb1:
+    return
+  ^bb2:
+    return
+}



More information about the flang-commits mailing list