[Mlir-commits] [mlir] [mlir][cf] Preserve branch weights during cf.cond_br canonicalization. (PR #144822)

Slava Zakharin llvmlistbot at llvm.org
Wed Jun 18 18:42:27 PDT 2025


https://github.com/vzakhari created https://github.com/llvm/llvm-project/pull/144822

(No pull request description was provided.)

>From 7d66f464d8fbf0ce2732da1e6fecd7edbfd73fdf Mon Sep 17 00:00:00 2001
From: Slava Zakharin <szakharin at nvidia.com>
Date: Wed, 18 Jun 2025 18:27:40 -0700
Subject: [PATCH] [mlir][cf] Preserve branch weights during cf.cond_br
 canonicalization.

---
 .../Dialect/ControlFlow/IR/ControlFlowOps.td  | 20 ++++++++++-----
 .../Dialect/ControlFlow/IR/ControlFlowOps.cpp |  6 ++---
 .../Dialect/ControlFlow/canonicalize.mlir     | 25 +++++++++++++++++++
 3 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index 79da81ba049dd..a441fd82546e3 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -153,17 +153,24 @@ def CondBranchOp
   let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
                                  "ValueRange":$trueOperands,
                                  "Block *":$falseDest,
-                                "ValueRange":$falseOperands),
+                                "ValueRange":$falseOperands,
+                                CArg<"ArrayRef<int32_t>", "{}">:$branchWeights),
                              [{
-      build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
-            falseDest);
+      // An empty `branchWeights` leaves the `branch_weights` attribute null
+      // (unset) instead of attaching an empty weight array.
+      DenseI32ArrayAttr weights;
+      if (!branchWeights.empty())
+        weights = $_builder.getDenseI32ArrayAttr(branchWeights);
+      build($_builder, $_state, condition, trueOperands, falseOperands,
+            weights, trueDest, falseDest);
     }]>,
                    OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
                                  "Block *":$falseDest,
-                                CArg<"ValueRange", "{}">:$falseOperands),
+                                CArg<"ValueRange", "{}">:$falseOperands,
+                                CArg<"ArrayRef<int32_t>", "{}">:$branchWeights),
                              [{
       build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
-            falseOperands);
+            falseOperands, branchWeights);
     }]>];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index edd7f607f24f4..0c11c76cf1f71 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -265,9 +265,9 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
       return failure();
 
     // Create a new branch with the collapsed successors.
-    rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
-                                              trueDest, trueDestOperands,
-                                              falseDest, falseDestOperands);
+    rewriter.replaceOpWithNewOp<CondBranchOp>(
+        condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
+        falseDestOperands, condbr.getWeights());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 0ad6898fce86c..bf69935a00bf0 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -102,6 +102,31 @@ func.func @cond_br_and_br_folding(%a : i32) {
 
 /// Test that pass-through successors of CondBranchOp get folded.
 
+// Test that branch weights are preserved when pass-through successors are folded.
+// CHECK-LABEL:   func.func @cond_br_passthrough_weights(
+// CHECK-SAME:      %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i1) -> i32 {
+func.func @cond_br_passthrough_weights(%arg0 : i32, %arg1 : i32, %cond : i1) -> i32 {
+// CHECK:           cf.cond_br %[[ARG2]] weights([30, 70]), ^bb1, ^bb2
+// CHECK:         ^bb1:
+// CHECK:           return %[[ARG0]] : i32
+// CHECK:         ^bb2:
+// CHECK:           return %[[ARG1]] : i32
+// CHECK:         }
+  cf.cond_br %cond weights([30,70]), ^bb1, ^bb3
+
+^bb1:
+  cf.br ^bb2
+
+^bb3:
+  cf.br ^bb4
+
+^bb2:
+  return %arg0 : i32
+
+^bb4:
+  return %arg1 : i32
+}
+
 // CHECK-LABEL: func @cond_br_passthrough(
 // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
 func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {



More information about the Mlir-commits mailing list